1use zeph_llm::any::AnyProvider;
5use zeph_llm::provider::{LlmProvider, Message, MessageMetadata, Role};
6
7use std::sync::Arc;
8
9use crate::embedding_store::{EmbeddingStore, MessageKind, SearchFilter};
10use crate::error::MemoryError;
11use crate::sqlite::SqliteStore;
12use crate::token_counter::TokenCounter;
13use crate::types::{ConversationId, MessageId};
14use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
15
16const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
17const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
18const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
19
/// Structured output of LLM-driven conversation summarization.
///
/// Deserialized from the model's JSON response; the requested shape is
/// described in `build_summarization_prompt`.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    // Free-form prose summary of the summarized message window.
    pub summary: String,
    // Discrete facts extracted from the conversation.
    pub key_facts: Vec<String>,
    // Named entities mentioned in the conversation.
    pub entities: Vec<String>,
}
26
/// A message returned by hybrid recall together with its relevance score
/// (higher is more relevant).
#[derive(Debug)]
pub struct RecalledMessage {
    pub message: Message,
    // Blended vector + keyword score, possibly adjusted by temporal decay/MMR.
    pub score: f32,
}
32
/// A persisted summary covering a contiguous range of conversation messages.
#[derive(Debug, Clone)]
pub struct Summary {
    pub id: i64,
    pub conversation_id: ConversationId,
    pub content: String,
    // Inclusive bounds of the message range this summary covers.
    pub first_message_id: MessageId,
    pub last_message_id: MessageId,
    // Approximate token count of `content`, computed at save time.
    pub token_estimate: i64,
}
42
/// A session-summary hit from cross-conversation semantic search.
#[derive(Debug, Clone)]
pub struct SessionSummaryResult {
    pub summary_text: String,
    // Vector-store similarity score for the query.
    pub score: f32,
    // Conversation the summary was originally stored for.
    pub conversation_id: ConversationId,
}
49
/// Cosine similarity of two vectors; `0.0` for empty, mismatched-length,
/// or zero-magnitude inputs.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Mismatched or empty inputs carry no meaningful similarity.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Single pass accumulating the dot product and both squared magnitudes.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    // A zero-magnitude vector has no direction; define similarity as 0.
    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }
    dot / (sq_a.sqrt() * sq_b.sqrt())
}
62
63fn apply_temporal_decay(
64 ranked: &mut [(MessageId, f64)],
65 timestamps: &std::collections::HashMap<MessageId, i64>,
66 half_life_days: u32,
67) {
68 if half_life_days == 0 {
69 return;
70 }
71 let now = std::time::SystemTime::now()
72 .duration_since(std::time::UNIX_EPOCH)
73 .unwrap_or_default()
74 .as_secs()
75 .cast_signed();
76 let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
77
78 for (msg_id, score) in ranked.iter_mut() {
79 if let Some(&ts) = timestamps.get(msg_id) {
80 #[allow(clippy::cast_precision_loss)]
81 let age_days = (now - ts).max(0) as f64 / 86400.0;
82 *score *= (-lambda * age_days).exp();
83 }
84 }
85}
86
/// Re-ranks `ranked` with Maximal Marginal Relevance: greedily picks items
/// that balance relevance against similarity to already-selected items.
///
/// `lambda` near 1.0 favors relevance; near 0.0 favors diversity. `ranked`
/// is assumed sorted by descending relevance (the first pick is index 0).
/// Candidates missing from `vectors` are treated as having zero similarity.
fn apply_mmr(
    ranked: &[(MessageId, f64)],
    vectors: &std::collections::HashMap<MessageId, Vec<f32>>,
    lambda: f32,
    limit: usize,
) -> Vec<(MessageId, f64)> {
    if ranked.is_empty() || limit == 0 {
        return Vec::new();
    }

    let lambda = f64::from(lambda);
    let mut selected: Vec<(MessageId, f64)> = Vec::with_capacity(limit);
    let mut remaining: Vec<(MessageId, f64)> = ranked.to_vec();

    while selected.len() < limit && !remaining.is_empty() {
        let best_idx = if selected.is_empty() {
            // Nothing selected yet: take the most relevant item outright.
            0
        } else {
            let mut best = 0usize;
            let mut best_score = f64::NEG_INFINITY;

            for (i, &(cand_id, relevance)) in remaining.iter().enumerate() {
                // Highest similarity between the candidate and anything
                // already selected (the "redundancy" term).
                let max_sim = if let Some(cand_vec) = vectors.get(&cand_id) {
                    selected
                        .iter()
                        .filter_map(|(sel_id, _)| vectors.get(sel_id))
                        .map(|sel_vec| f64::from(cosine_similarity(cand_vec, sel_vec)))
                        .fold(f64::NEG_INFINITY, f64::max)
                } else {
                    0.0
                };
                // NEG_INFINITY means no selected item had a vector; treat as
                // no redundancy.
                let max_sim = if max_sim == f64::NEG_INFINITY {
                    0.0
                } else {
                    max_sim
                };
                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim;
                // Strict '>' keeps the earliest (most relevant) item on ties.
                if mmr_score > best_score {
                    best_score = mmr_score;
                    best = i;
                }
            }
            best
        };

        selected.push(remaining.remove(best_idx));
    }

    selected
}
138
139fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
140 let mut prompt = String::from(
141 "Summarize the following conversation. Extract key facts, decisions, entities, \
142 and context needed to continue the conversation.\n\n\
143 Respond in JSON with fields: summary (string), key_facts (list of strings), \
144 entities (list of strings).\n\nConversation:\n",
145 );
146
147 for (_, role, content) in messages {
148 prompt.push_str(role);
149 prompt.push_str(": ");
150 prompt.push_str(content);
151 prompt.push('\n');
152 }
153
154 prompt
155}
156
/// Conversation memory combining durable SQLite storage with an optional
/// embedding store for semantic search. When the vector store is absent or
/// unreachable, all operations degrade gracefully to SQLite-only behavior.
pub struct SemanticMemory {
    // Message, summary, and correction persistence plus FTS5 keyword search.
    sqlite: SqliteStore,
    // `None` when the vector store could not be reached at construction.
    qdrant: Option<EmbeddingStore>,
    // LLM provider used for embeddings and summarization.
    provider: AnyProvider,
    // Model name recorded alongside stored embeddings.
    embedding_model: String,
    // Blend weights for vector vs. keyword scores in `recall`.
    vector_weight: f64,
    keyword_weight: f64,
    // Optional recall re-ranking knobs (see `with_ranking_options`).
    temporal_decay_enabled: bool,
    temporal_decay_half_life_days: u32,
    mmr_enabled: bool,
    mmr_lambda: f32,
    pub token_counter: Arc<TokenCounter>,
}
170
171impl SemanticMemory {
    /// Creates a memory backed by SQLite at `sqlite_path` and a vector store
    /// at `qdrant_url`, with default blend weights (0.7 vector, 0.3 keyword).
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn new(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
    ) -> Result<Self, MemoryError> {
        Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
    }
187
    /// Like `new`, but with explicit vector/keyword blend weights and the
    /// default SQLite connection pool size (5).
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_weights(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_weights_and_pool_size(
            sqlite_path,
            qdrant_url,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            5,
        )
        .await
    }
212
213 pub async fn with_weights_and_pool_size(
219 sqlite_path: &str,
220 qdrant_url: &str,
221 provider: AnyProvider,
222 embedding_model: &str,
223 vector_weight: f64,
224 keyword_weight: f64,
225 pool_size: u32,
226 ) -> Result<Self, MemoryError> {
227 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
228 let pool = sqlite.pool().clone();
229
230 let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
231 Ok(store) => Some(store),
232 Err(e) => {
233 tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
234 None
235 }
236 };
237
238 Ok(Self {
239 sqlite,
240 qdrant,
241 provider,
242 embedding_model: embedding_model.into(),
243 vector_weight,
244 keyword_weight,
245 temporal_decay_enabled: false,
246 temporal_decay_half_life_days: 30,
247 mmr_enabled: false,
248 mmr_lambda: 0.7,
249 token_counter: Arc::new(TokenCounter::new()),
250 })
251 }
252
253 #[must_use]
255 pub fn with_ranking_options(
256 mut self,
257 temporal_decay_enabled: bool,
258 temporal_decay_half_life_days: u32,
259 mmr_enabled: bool,
260 mmr_lambda: f32,
261 ) -> Self {
262 self.temporal_decay_enabled = temporal_decay_enabled;
263 self.temporal_decay_half_life_days = temporal_decay_half_life_days;
264 self.mmr_enabled = mmr_enabled;
265 self.mmr_lambda = mmr_lambda;
266 self
267 }
268
    /// Creates a memory that stores embeddings in SQLite itself (no external
    /// vector store), with the default pool size (5).
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_sqlite_backend(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_sqlite_backend_and_pool_size(
            sqlite_path,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            5,
        )
        .await
    }
291
    /// Creates a memory whose embedding store lives in the same SQLite
    /// database, with an explicit connection pool size. Unlike the Qdrant
    /// constructor, the vector store here is always present.
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened.
    pub async fn with_sqlite_backend_and_pool_size(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();
        let store = EmbeddingStore::new_sqlite(pool);

        Ok(Self {
            sqlite,
            qdrant: Some(store),
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        })
    }
323
    /// Persists a message to SQLite and, best-effort, stores its embedding.
    ///
    /// Embedding failures (provider, collection, or store errors) are logged
    /// and never propagated — the message is durable regardless.
    ///
    /// # Errors
    /// Returns an error only if the SQLite write fails.
    pub async fn remember(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
    ) -> Result<MessageId, MemoryError> {
        let message_id = self
            .sqlite
            .save_message(conversation_id, role, content)
            .await?;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    // Size the collection from the actual embedding; 896 is a
                    // fallback if the length somehow exceeds u64 (practically
                    // unreachable).
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok(message_id)
    }
374
    /// Like `remember`, but persists structured message parts (JSON) alongside
    /// the text and reports whether the embedding was actually stored.
    ///
    /// Returns `(message_id, embedding_stored)`; `embedding_stored` is `false`
    /// whenever any step of the embedding pipeline failed or was skipped.
    ///
    /// # Errors
    /// Returns an error only if the SQLite write fails.
    pub async fn remember_with_parts(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<(MessageId, bool), MemoryError> {
        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        let mut embedding_stored = false;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    } else {
                        embedding_stored = true;
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok((message_id, embedding_stored))
    }
429
    /// Persists a message (with parts) to SQLite without attempting any
    /// embedding — for callers that handle embedding separately or not at all.
    ///
    /// # Errors
    /// Returns an error if the SQLite write fails.
    pub async fn save_only(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<MessageId, MemoryError> {
        self.sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await
    }
448
    /// Hybrid recall: blends FTS5 keyword search with vector similarity,
    /// optionally applies temporal decay and MMR, and returns up to `limit`
    /// messages with their combined scores.
    ///
    /// Keyword-search failures are tolerated (logged, treated as empty);
    /// vector-store failures after a successful connection are propagated.
    ///
    /// # Errors
    /// Returns an error if embedding the query, the vector search, or the
    /// final SQLite message fetch fails.
    #[allow(clippy::too_many_lines)]
    pub async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        // Keyword leg is best-effort: FTS5 syntax errors etc. degrade to empty.
        // Over-fetch (limit * 2) so the blend has candidates from both legs.
        let keyword_results = match self
            .sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
        {
            Ok(results) => results,
            Err(e) => {
                tracing::warn!("FTS5 keyword search failed: {e:#}");
                Vec::new()
            }
        };

        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            let query_vector = self.provider.embed(query).await?;
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant.search(&query_vector, limit * 2, filter).await?
        } else {
            Vec::new()
        };

        // Blend: normalize each leg by its own max score, then weight and sum.
        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            // Guard against non-positive maxima (avoid dividing by <= 0).
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            return Ok(Vec::new());
        }

        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Optional: down-weight older messages, then re-sort.
        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                }
                Err(e) => {
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // Optional: MMR diversification (needs the stored vectors); every
        // fallback path still truncates to `limit`.
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                    }
                    Ok(_) => {
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Hydrate message bodies from SQLite, preserving ranked order;
        // messages deleted since scoring are silently dropped.
        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        Ok(recalled)
    }
587
588 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
594 match &self.qdrant {
595 Some(qdrant) => qdrant.has_embedding(message_id).await,
596 None => Ok(false),
597 }
598 }
599
    /// Backfills embeddings for up to 1000 stored messages that have none,
    /// returning the number successfully embedded. Per-message failures are
    /// logged and skipped.
    ///
    /// # Errors
    /// Returns an error if fetching the unembedded set, the dimension probe,
    /// or collection creation fails.
    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(0);
        };
        if !self.provider.supports_embeddings() {
            return Ok(0);
        }

        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;

        if unembedded.is_empty() {
            return Ok(0);
        }

        // Probe embedding determines the vector dimension for the collection.
        let probe = self.provider.embed("probe").await?;
        let vector_size = u64::try_from(probe.len())?;
        qdrant.ensure_collection(vector_size).await?;

        let mut count = 0;
        for (msg_id, conversation_id, role, content) in &unembedded {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    if let Err(e) = qdrant
                        .store(
                            *msg_id,
                            *conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
                        continue;
                    }
                    count += 1;
                }
                Err(e) => {
                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
                }
            }
        }

        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
        Ok(count)
    }
655
    /// Embeds and stores a whole-session summary in the dedicated cross-
    /// conversation collection. A no-op when the vector store or embedding
    /// support is absent.
    ///
    /// # Errors
    /// Returns an error if embedding, collection creation, or the store fails.
    pub async fn store_session_summary(
        &self,
        conversation_id: ConversationId,
        summary_text: &str,
    ) -> Result<(), MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(());
        };
        if !self.provider.supports_embeddings() {
            return Ok(());
        }

        let vector = self.provider.embed(summary_text).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        // Payload carries everything needed to reconstruct a search hit.
        let payload = serde_json::json!({
            "conversation_id": conversation_id.0,
            "summary_text": summary_text,
        });

        qdrant
            .store_to_collection(SESSION_SUMMARIES_COLLECTION, payload, vector)
            .await?;

        tracing::debug!(
            conversation_id = conversation_id.0,
            "stored session summary"
        );
        Ok(())
    }
694
    /// Semantic search over stored session summaries, optionally excluding
    /// one conversation (typically the current one). Returns empty when the
    /// vector store or embedding support is absent.
    ///
    /// # Errors
    /// Returns an error if embedding the query, collection creation, or the
    /// vector search fails.
    pub async fn search_session_summaries(
        &self,
        query: &str,
        limit: usize,
        exclude_conversation_id: Option<ConversationId>,
    ) -> Result<Vec<SessionSummaryResult>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        // Ensure the collection so searching a fresh store yields empty
        // results instead of a missing-collection error.
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        let filter = exclude_conversation_id.map(|cid| VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "conversation_id".into(),
                value: FieldValue::Integer(cid.0),
            }],
        });

        let points = qdrant
            .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter)
            .await?;

        // Points with malformed payloads are silently dropped.
        let results = points
            .into_iter()
            .filter_map(|point| {
                let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned();
                let conversation_id =
                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
                Some(SessionSummaryResult {
                    summary_text,
                    score: point.score,
                    conversation_id,
                })
            })
            .collect();

        Ok(results)
    }
747
    /// Direct read access to the underlying SQLite store.
    #[must_use]
    pub fn sqlite(&self) -> &SqliteStore {
        &self.sqlite
    }
753
754 pub async fn is_vector_store_connected(&self) -> bool {
759 match self.qdrant.as_ref() {
760 Some(store) => store.health_check().await,
761 None => false,
762 }
763 }
764
    /// Whether a vector store was configured at construction time (does not
    /// probe connectivity; see `is_vector_store_connected` for that).
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }
770
    /// Total number of messages stored for `conversation_id`.
    ///
    /// # Errors
    /// Propagates SQLite failures.
    pub async fn message_count(&self, conversation_id: ConversationId) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }
779
    /// Number of messages not yet covered by any summary: everything after
    /// the latest summary's `last_message_id` (or all messages when no
    /// summary exists — `MessageId(0)` precedes every real id).
    ///
    /// # Errors
    /// Propagates SQLite failures.
    pub async fn unsummarized_message_count(
        &self,
        conversation_id: ConversationId,
    ) -> Result<i64, MemoryError> {
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));
        self.sqlite
            .count_messages_after(conversation_id, after_id)
            .await
    }
798
799 pub async fn load_summaries(
805 &self,
806 conversation_id: ConversationId,
807 ) -> Result<Vec<Summary>, MemoryError> {
808 let rows = self.sqlite.load_summaries(conversation_id).await?;
809 let summaries = rows
810 .into_iter()
811 .map(
812 |(
813 id,
814 conversation_id,
815 content,
816 first_message_id,
817 last_message_id,
818 token_estimate,
819 )| {
820 Summary {
821 id,
822 conversation_id,
823 content,
824 first_message_id,
825 last_message_id,
826 token_estimate,
827 }
828 },
829 )
830 .collect();
831 Ok(summaries)
832 }
833
    /// Summarizes the oldest `message_count` unsummarized messages once the
    /// conversation exceeds that many messages total; returns the new summary
    /// row id, or `None` when below threshold / nothing left to summarize.
    ///
    /// Tries structured (typed JSON) summarization first and falls back to a
    /// plain-text chat completion. Best-effort side effects: embedding the
    /// summary and storing extracted key facts never fail the call.
    ///
    /// # Errors
    /// Returns an error if SQLite access fails, both summarization paths
    /// fail, or the token-count conversion overflows.
    pub async fn summarize(
        &self,
        conversation_id: ConversationId,
        message_count: usize,
    ) -> Result<Option<i64>, MemoryError> {
        let total = self.sqlite.count_messages(conversation_id).await?;

        // Only summarize once there are strictly more messages than the window.
        if total <= i64::try_from(message_count)? {
            return Ok(None);
        }

        // Resume after the last summarized message, if any.
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));

        let messages = self
            .sqlite
            .load_messages_range(conversation_id, after_id, message_count)
            .await?;

        if messages.is_empty() {
            return Ok(None);
        }

        let prompt = build_summarization_prompt(&messages);
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        // Prefer structured output; degrade to plain text with empty
        // key_facts/entities when the typed path fails.
        let structured = match self
            .provider
            .chat_typed_erased::<StructuredSummary>(&chat_messages)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!(
                    "structured summarization failed, falling back to plain text: {e:#}"
                );
                let plain = self.provider.chat(&chat_messages).await?;
                StructuredSummary {
                    summary: plain,
                    key_facts: vec![],
                    entities: vec![],
                }
            }
        };
        let summary_text = &structured.summary;

        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
        // Non-empty here (checked above), so indexing is safe.
        let first_message_id = messages[0].0;
        let last_message_id = messages[messages.len() - 1].0;

        let summary_id = self
            .sqlite
            .save_summary(
                conversation_id,
                summary_text,
                first_message_id,
                last_message_id,
                token_estimate,
            )
            .await?;

        // Best-effort: embed the summary itself so recall can surface it.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(summary_text).await {
                Ok(vector) => {
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            // Summary rows share the id space used for points;
                            // kind distinguishes them from regular messages.
                            MessageId(summary_id),
                            conversation_id,
                            "system",
                            vector,
                            MessageKind::Summary,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to embed summary: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate summary embedding: {e:#}");
                }
            }
        }

        if !structured.key_facts.is_empty() {
            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
                .await;
        }

        Ok(Some(summary_id))
    }
947
948 async fn store_key_facts(
949 &self,
950 conversation_id: ConversationId,
951 source_summary_id: i64,
952 key_facts: &[String],
953 ) {
954 let Some(qdrant) = &self.qdrant else {
955 return;
956 };
957 if !self.provider.supports_embeddings() {
958 return;
959 }
960
961 let Some(first_fact) = key_facts.first() else {
962 return;
963 };
964 let first_vector = match self.provider.embed(first_fact).await {
965 Ok(v) => v,
966 Err(e) => {
967 tracing::warn!("Failed to embed key fact: {e:#}");
968 return;
969 }
970 };
971 let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
972 if let Err(e) = qdrant
973 .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
974 .await
975 {
976 tracing::warn!("Failed to ensure key_facts collection: {e:#}");
977 return;
978 }
979
980 let first_payload = serde_json::json!({
981 "conversation_id": conversation_id.0,
982 "fact_text": first_fact,
983 "source_summary_id": source_summary_id,
984 });
985 if let Err(e) = qdrant
986 .store_to_collection(KEY_FACTS_COLLECTION, first_payload, first_vector)
987 .await
988 {
989 tracing::warn!("Failed to store key fact: {e:#}");
990 }
991
992 for fact in &key_facts[1..] {
993 match self.provider.embed(fact).await {
994 Ok(vector) => {
995 let payload = serde_json::json!({
996 "conversation_id": conversation_id.0,
997 "fact_text": fact,
998 "source_summary_id": source_summary_id,
999 });
1000 if let Err(e) = qdrant
1001 .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
1002 .await
1003 {
1004 tracing::warn!("Failed to store key fact: {e:#}");
1005 }
1006 }
1007 Err(e) => {
1008 tracing::warn!("Failed to embed key fact: {e:#}");
1009 }
1010 }
1011 }
1012 }
1013
    /// Semantic search over stored key facts, returning up to `limit` fact
    /// strings. Empty when the vector store or embedding support is absent.
    ///
    /// # Errors
    /// Returns an error if embedding the query, collection creation, or the
    /// vector search fails.
    pub async fn search_key_facts(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<String>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        // Ensure the collection so a fresh store returns empty, not an error.
        qdrant
            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
            .await?;

        let points = qdrant
            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
            .await?;

        // Points lacking a textual `fact_text` payload are dropped.
        let facts = points
            .into_iter()
            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
            .collect();

        Ok(facts)
    }
1048
    /// Semantic search over an arbitrary named document collection, returning
    /// raw scored points. Empty when the vector store or embedding support is
    /// absent, or when the collection does not exist (it is never created).
    ///
    /// # Errors
    /// Returns an error if the existence check, query embedding, or search
    /// fails.
    pub async fn search_document_collection(
        &self,
        collection: &str,
        query: &str,
        limit: usize,
    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }
        if !qdrant.collection_exists(collection).await? {
            return Ok(Vec::new());
        }
        let vector = self.provider.embed(query).await?;
        qdrant
            .search_collection(collection, &vector, limit, None)
            .await
    }
1078
1079 pub async fn store_correction_embedding(
1087 &self,
1088 correction_id: i64,
1089 correction_text: &str,
1090 ) -> Result<(), MemoryError> {
1091 let Some(ref store) = self.qdrant else {
1092 return Ok(());
1093 };
1094 if !self.provider.supports_embeddings() {
1095 return Ok(());
1096 }
1097 let embedding = self
1098 .provider
1099 .embed(correction_text)
1100 .await
1101 .map_err(|e| MemoryError::Other(e.to_string()))?;
1102 let payload = serde_json::json!({ "correction_id": correction_id });
1103 store
1104 .store_to_collection(CORRECTIONS_COLLECTION, payload, embedding)
1105 .await?;
1106 Ok(())
1107 }
1108
    /// Finds stored user corrections semantically similar to `query`, keeping
    /// only hits scoring at least `min_score`, and hydrates the matching rows
    /// from SQLite.
    ///
    /// The vector search itself is best-effort (failures yield no hits), but
    /// SQLite hydration errors are propagated.
    ///
    /// # Errors
    /// Returns an error if embedding the query or loading correction rows
    /// fails.
    pub async fn retrieve_similar_corrections(
        &self,
        query: &str,
        limit: usize,
        min_score: f32,
    ) -> Result<Vec<crate::sqlite::corrections::UserCorrectionRow>, MemoryError> {
        let Some(ref store) = self.qdrant else {
            return Ok(vec![]);
        };
        if !self.provider.supports_embeddings() {
            return Ok(vec![]);
        }
        let embedding = self
            .provider
            .embed(query)
            .await
            .map_err(|e| MemoryError::Other(e.to_string()))?;
        // Search failures (e.g. collection missing) degrade to zero hits.
        let scored = store
            .search_collection(CORRECTIONS_COLLECTION, &embedding, limit, None)
            .await
            .unwrap_or_default();

        let mut results = Vec::new();
        for point in scored {
            if point.score < min_score {
                continue;
            }
            // Only points carrying an integer `correction_id` payload are
            // resolvable back to SQLite rows.
            if let Some(id_val) = point.payload.get("correction_id")
                && let Some(id) = id_val.as_i64()
            {
                let rows = self.sqlite.load_corrections_for_id(id).await?;
                results.extend(rows);
            }
        }
        Ok(results)
    }
1153}
1154
1155#[cfg(test)]
1156mod tests {
1157 use zeph_llm::mock::MockProvider;
1158 use zeph_llm::provider::Role;
1159
1160 use super::*;
1161
1162 fn test_provider() -> AnyProvider {
1163 AnyProvider::Mock(MockProvider::default())
1164 }
1165
1166 async fn test_semantic_memory(_supports_embeddings: bool) -> SemanticMemory {
1167 let provider = test_provider();
1168 let sqlite = SqliteStore::new(":memory:").await.unwrap();
1169
1170 SemanticMemory {
1171 sqlite,
1172 qdrant: None,
1173 provider,
1174 embedding_model: "test-model".into(),
1175 vector_weight: 0.7,
1176 keyword_weight: 0.3,
1177 temporal_decay_enabled: false,
1178 temporal_decay_half_life_days: 30,
1179 mmr_enabled: false,
1180 mmr_lambda: 0.7,
1181 token_counter: Arc::new(TokenCounter::new()),
1182 }
1183 }
1184
1185 #[tokio::test]
1186 async fn remember_saves_to_sqlite() {
1187 let memory = test_semantic_memory(false).await;
1188
1189 let cid = memory.sqlite.create_conversation().await.unwrap();
1190 let msg_id = memory.remember(cid, "user", "hello").await.unwrap();
1191
1192 assert_eq!(msg_id, MessageId(1));
1193
1194 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1195 assert_eq!(history.len(), 1);
1196 assert_eq!(history[0].role, Role::User);
1197 assert_eq!(history[0].content, "hello");
1198 }
1199
1200 #[tokio::test]
1201 async fn remember_with_parts_saves_parts_json() {
1202 let memory = test_semantic_memory(false).await;
1203 let cid = memory.sqlite.create_conversation().await.unwrap();
1204
1205 let parts_json =
1206 r#"[{"kind":"ToolOutput","tool_name":"shell","body":"hello","compacted_at":null}]"#;
1207 let (msg_id, _embedding_stored) = memory
1208 .remember_with_parts(cid, "assistant", "tool output", parts_json)
1209 .await
1210 .unwrap();
1211 assert!(msg_id > MessageId(0));
1212
1213 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1214 assert_eq!(history.len(), 1);
1215 assert_eq!(history[0].content, "tool output");
1216 }
1217
1218 #[tokio::test]
1219 async fn recall_returns_empty_without_qdrant() {
1220 let memory = test_semantic_memory(true).await;
1221
1222 let recalled = memory.recall("test", 5, None).await.unwrap();
1223 assert!(recalled.is_empty());
1224 }
1225
1226 #[tokio::test]
1227 async fn has_embedding_without_qdrant() {
1228 let memory = test_semantic_memory(true).await;
1229
1230 let has_embedding = memory.has_embedding(MessageId(1)).await.unwrap();
1231 assert!(!has_embedding);
1232 }
1233
1234 #[tokio::test]
1235 async fn embed_missing_without_qdrant() {
1236 let memory = test_semantic_memory(true).await;
1237
1238 let count = memory.embed_missing().await.unwrap();
1239 assert_eq!(count, 0);
1240 }
1241
1242 #[tokio::test]
1243 async fn sqlite_accessor() {
1244 let memory = test_semantic_memory(false).await;
1245
1246 let cid = memory.sqlite().create_conversation().await.unwrap();
1247 assert_eq!(cid, ConversationId(1));
1248
1249 memory
1250 .sqlite()
1251 .save_message(cid, "user", "test")
1252 .await
1253 .unwrap();
1254
1255 let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1256 assert_eq!(history.len(), 1);
1257 }
1258
1259 #[tokio::test]
1260 async fn has_vector_store_returns_false_when_unavailable() {
1261 let memory = test_semantic_memory(false).await;
1262 assert!(!memory.has_vector_store());
1263 }
1264
1265 #[tokio::test]
1266 async fn is_vector_store_connected_returns_false_when_unavailable() {
1267 let memory = test_semantic_memory(false).await;
1268 assert!(!memory.is_vector_store_connected().await);
1269 }
1270
1271 #[tokio::test]
1272 async fn recall_returns_empty_when_embeddings_not_supported() {
1273 let memory = test_semantic_memory(false).await;
1274
1275 let recalled = memory.recall("test", 5, None).await.unwrap();
1276 assert!(recalled.is_empty());
1277 }
1278
1279 #[tokio::test]
1280 async fn embed_missing_returns_zero_when_embeddings_not_supported() {
1281 let memory = test_semantic_memory(false).await;
1282
1283 let cid = memory.sqlite().create_conversation().await.unwrap();
1284 memory
1285 .sqlite()
1286 .save_message(cid, "user", "test")
1287 .await
1288 .unwrap();
1289
1290 let count = memory.embed_missing().await.unwrap();
1291 assert_eq!(count, 0);
1292 }
1293
1294 #[tokio::test]
1295 async fn message_count_empty_conversation() {
1296 let memory = test_semantic_memory(false).await;
1297 let cid = memory.sqlite().create_conversation().await.unwrap();
1298
1299 let count = memory.message_count(cid).await.unwrap();
1300 assert_eq!(count, 0);
1301 }
1302
1303 #[tokio::test]
1304 async fn message_count_after_saves() {
1305 let memory = test_semantic_memory(false).await;
1306 let cid = memory.sqlite().create_conversation().await.unwrap();
1307
1308 memory.remember(cid, "user", "msg1").await.unwrap();
1309 memory.remember(cid, "assistant", "msg2").await.unwrap();
1310
1311 let count = memory.message_count(cid).await.unwrap();
1312 assert_eq!(count, 2);
1313 }
1314
1315 #[tokio::test]
1316 async fn unsummarized_count_decreases_after_summary() {
1317 let memory = test_semantic_memory(false).await;
1318 let cid = memory.sqlite().create_conversation().await.unwrap();
1319
1320 for i in 0..10 {
1321 memory
1322 .remember(cid, "user", &format!("msg{i}"))
1323 .await
1324 .unwrap();
1325 }
1326 assert_eq!(memory.unsummarized_message_count(cid).await.unwrap(), 10);
1327
1328 memory.summarize(cid, 5).await.unwrap();
1329
1330 assert!(memory.unsummarized_message_count(cid).await.unwrap() < 10);
1331 assert_eq!(memory.message_count(cid).await.unwrap(), 10);
1332 }
1333
1334 #[tokio::test]
1335 async fn load_summaries_empty() {
1336 let memory = test_semantic_memory(false).await;
1337 let cid = memory.sqlite().create_conversation().await.unwrap();
1338
1339 let summaries = memory.load_summaries(cid).await.unwrap();
1340 assert!(summaries.is_empty());
1341 }
1342
1343 #[tokio::test]
1344 async fn load_summaries_ordered() {
1345 let memory = test_semantic_memory(false).await;
1346 let cid = memory.sqlite().create_conversation().await.unwrap();
1347
1348 let msg_id1 = memory.remember(cid, "user", "m1").await.unwrap();
1349 let msg_id2 = memory.remember(cid, "assistant", "m2").await.unwrap();
1350 let msg_id3 = memory.remember(cid, "user", "m3").await.unwrap();
1351
1352 let s1 = memory
1353 .sqlite()
1354 .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
1355 .await
1356 .unwrap();
1357 let s2 = memory
1358 .sqlite()
1359 .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
1360 .await
1361 .unwrap();
1362
1363 let summaries = memory.load_summaries(cid).await.unwrap();
1364 assert_eq!(summaries.len(), 2);
1365 assert_eq!(summaries[0].id, s1);
1366 assert_eq!(summaries[0].content, "summary1");
1367 assert_eq!(summaries[1].id, s2);
1368 assert_eq!(summaries[1].content, "summary2");
1369 }
1370
1371 #[tokio::test]
1372 async fn summarize_below_threshold() {
1373 let memory = test_semantic_memory(false).await;
1374 let cid = memory.sqlite().create_conversation().await.unwrap();
1375
1376 memory.remember(cid, "user", "hello").await.unwrap();
1377
1378 let result = memory.summarize(cid, 10).await.unwrap();
1379 assert!(result.is_none());
1380 }
1381
1382 #[tokio::test]
1383 async fn summarize_stores_summary() {
1384 let memory = test_semantic_memory(false).await;
1385 let cid = memory.sqlite().create_conversation().await.unwrap();
1386
1387 for i in 0..5 {
1388 memory
1389 .remember(cid, "user", &format!("message {i}"))
1390 .await
1391 .unwrap();
1392 }
1393
1394 let summary_id = memory.summarize(cid, 3).await.unwrap();
1395 assert!(summary_id.is_some());
1396
1397 let summaries = memory.load_summaries(cid).await.unwrap();
1398 assert_eq!(summaries.len(), 1);
1399 assert_eq!(summaries[0].id, summary_id.unwrap());
1400 assert!(!summaries[0].content.is_empty());
1401 }
1402
1403 #[tokio::test]
1404 async fn summarize_respects_previous_summaries() {
1405 let memory = test_semantic_memory(false).await;
1406 let cid = memory.sqlite().create_conversation().await.unwrap();
1407
1408 for i in 0..10 {
1409 memory
1410 .remember(cid, "user", &format!("message {i}"))
1411 .await
1412 .unwrap();
1413 }
1414
1415 let s1 = memory.summarize(cid, 3).await.unwrap();
1416 assert!(s1.is_some());
1417
1418 let s2 = memory.summarize(cid, 3).await.unwrap();
1419 assert!(s2.is_some());
1420
1421 let summaries = memory.load_summaries(cid).await.unwrap();
1422 assert_eq!(summaries.len(), 2);
1423 assert!(summaries[0].last_message_id < summaries[1].first_message_id);
1424 }
1425
1426 #[tokio::test]
1427 async fn remember_multiple_messages_increments_ids() {
1428 let memory = test_semantic_memory(false).await;
1429 let cid = memory.sqlite.create_conversation().await.unwrap();
1430
1431 let id1 = memory.remember(cid, "user", "first").await.unwrap();
1432 let id2 = memory.remember(cid, "assistant", "second").await.unwrap();
1433 let id3 = memory.remember(cid, "user", "third").await.unwrap();
1434
1435 assert!(id1 < id2);
1436 assert!(id2 < id3);
1437 }
1438
1439 #[tokio::test]
1440 async fn message_count_across_conversations() {
1441 let memory = test_semantic_memory(false).await;
1442 let cid1 = memory.sqlite().create_conversation().await.unwrap();
1443 let cid2 = memory.sqlite().create_conversation().await.unwrap();
1444
1445 memory.remember(cid1, "user", "msg1").await.unwrap();
1446 memory.remember(cid1, "user", "msg2").await.unwrap();
1447 memory.remember(cid2, "user", "msg3").await.unwrap();
1448
1449 assert_eq!(memory.message_count(cid1).await.unwrap(), 2);
1450 assert_eq!(memory.message_count(cid2).await.unwrap(), 1);
1451 }
1452
1453 #[tokio::test]
1454 async fn summarize_exact_threshold_returns_none() {
1455 let memory = test_semantic_memory(false).await;
1456 let cid = memory.sqlite().create_conversation().await.unwrap();
1457
1458 for i in 0..3 {
1459 memory
1460 .remember(cid, "user", &format!("msg {i}"))
1461 .await
1462 .unwrap();
1463 }
1464
1465 let result = memory.summarize(cid, 3).await.unwrap();
1466 assert!(result.is_none());
1467 }
1468
1469 #[tokio::test]
1470 async fn summarize_one_above_threshold_produces_summary() {
1471 let memory = test_semantic_memory(false).await;
1472 let cid = memory.sqlite().create_conversation().await.unwrap();
1473
1474 for i in 0..4 {
1475 memory
1476 .remember(cid, "user", &format!("msg {i}"))
1477 .await
1478 .unwrap();
1479 }
1480
1481 let result = memory.summarize(cid, 3).await.unwrap();
1482 assert!(result.is_some());
1483 }
1484
1485 #[tokio::test]
1486 async fn summary_fields_populated() {
1487 let memory = test_semantic_memory(false).await;
1488 let cid = memory.sqlite().create_conversation().await.unwrap();
1489
1490 for i in 0..5 {
1491 memory
1492 .remember(cid, "user", &format!("msg {i}"))
1493 .await
1494 .unwrap();
1495 }
1496
1497 memory.summarize(cid, 3).await.unwrap();
1498 let summaries = memory.load_summaries(cid).await.unwrap();
1499 let s = &summaries[0];
1500
1501 assert_eq!(s.conversation_id, cid);
1502 assert!(s.first_message_id > MessageId(0));
1503 assert!(s.last_message_id >= s.first_message_id);
1504 assert!(s.token_estimate >= 0);
1505 assert!(!s.content.is_empty());
1506 }
1507
1508 #[test]
1509 fn build_summarization_prompt_format() {
1510 let messages = vec![
1511 (MessageId(1), "user".into(), "Hello".into()),
1512 (MessageId(2), "assistant".into(), "Hi there".into()),
1513 ];
1514 let prompt = build_summarization_prompt(&messages);
1515 assert!(prompt.contains("user: Hello"));
1516 assert!(prompt.contains("assistant: Hi there"));
1517 assert!(prompt.contains("key_facts"));
1518 }
1519
1520 #[test]
1521 fn build_summarization_prompt_empty() {
1522 let messages: Vec<(MessageId, String, String)> = vec![];
1523 let prompt = build_summarization_prompt(&messages);
1524 assert!(prompt.contains("key_facts"));
1525 }
1526
1527 #[test]
1528 fn structured_summary_deserialize() {
1529 let json = r#"{"summary":"s","key_facts":["f1","f2"],"entities":["e1"]}"#;
1530 let ss: StructuredSummary = serde_json::from_str(json).unwrap();
1531 assert_eq!(ss.summary, "s");
1532 assert_eq!(ss.key_facts.len(), 2);
1533 assert_eq!(ss.entities.len(), 1);
1534 }
1535
1536 #[test]
1537 fn structured_summary_empty_facts() {
1538 let json = r#"{"summary":"s","key_facts":[],"entities":[]}"#;
1539 let ss: StructuredSummary = serde_json::from_str(json).unwrap();
1540 assert!(ss.key_facts.is_empty());
1541 assert!(ss.entities.is_empty());
1542 }
1543
1544 #[tokio::test]
1545 async fn search_key_facts_no_qdrant_empty() {
1546 let memory = test_semantic_memory(false).await;
1547 let facts = memory.search_key_facts("query", 5).await.unwrap();
1548 assert!(facts.is_empty());
1549 }
1550
1551 #[test]
1552 fn recalled_message_debug() {
1553 let recalled = RecalledMessage {
1554 message: Message {
1555 role: Role::User,
1556 content: "test".into(),
1557 parts: vec![],
1558 metadata: MessageMetadata::default(),
1559 },
1560 score: 0.95,
1561 };
1562 let dbg = format!("{recalled:?}");
1563 assert!(dbg.contains("RecalledMessage"));
1564 assert!(dbg.contains("0.95"));
1565 }
1566
1567 #[test]
1568 fn summary_clone() {
1569 let summary = Summary {
1570 id: 1,
1571 conversation_id: ConversationId(2),
1572 content: "test summary".into(),
1573 first_message_id: MessageId(1),
1574 last_message_id: MessageId(5),
1575 token_estimate: 10,
1576 };
1577 let cloned = summary.clone();
1578 assert_eq!(summary.id, cloned.id);
1579 assert_eq!(summary.content, cloned.content);
1580 }
1581
1582 #[tokio::test]
1583 async fn remember_preserves_role_mapping() {
1584 let memory = test_semantic_memory(false).await;
1585 let cid = memory.sqlite.create_conversation().await.unwrap();
1586
1587 memory.remember(cid, "user", "u").await.unwrap();
1588 memory.remember(cid, "assistant", "a").await.unwrap();
1589 memory.remember(cid, "system", "s").await.unwrap();
1590
1591 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1592 assert_eq!(history.len(), 3);
1593 assert_eq!(history[0].role, Role::User);
1594 assert_eq!(history[1].role, Role::Assistant);
1595 assert_eq!(history[2].role, Role::System);
1596 }
1597
1598 #[tokio::test]
1599 async fn new_with_invalid_qdrant_url_graceful() {
1600 let mut mock = MockProvider::default();
1601 mock.supports_embeddings = true;
1602 let provider = AnyProvider::Mock(mock);
1603 let result =
1604 SemanticMemory::new(":memory:", "http://127.0.0.1:1", provider, "test-model").await;
1605 assert!(result.is_ok());
1606 }
1607
1608 #[tokio::test]
1609 async fn test_semantic_memory_sqlite_remember_recall_roundtrip() {
1610 let mut mock = MockProvider::default();
1612 mock.supports_embeddings = true;
1613 let provider = AnyProvider::Mock(mock);
1616
1617 let sqlite = SqliteStore::new(":memory:").await.unwrap();
1618 let pool = sqlite.pool().clone();
1619 let qdrant = Some(crate::embedding_store::EmbeddingStore::new_sqlite(pool));
1620
1621 let memory = SemanticMemory {
1622 sqlite,
1623 qdrant,
1624 provider,
1625 embedding_model: "test-model".into(),
1626 vector_weight: 0.7,
1627 keyword_weight: 0.3,
1628 temporal_decay_enabled: false,
1629 temporal_decay_half_life_days: 30,
1630 mmr_enabled: false,
1631 mmr_lambda: 0.7,
1632 token_counter: Arc::new(TokenCounter::new()),
1633 };
1634
1635 let cid = memory.sqlite().create_conversation().await.unwrap();
1636
1637 let id1 = memory
1639 .remember(cid, "user", "rust async programming")
1640 .await
1641 .unwrap();
1642 let id2 = memory
1643 .remember(cid, "assistant", "use tokio for async")
1644 .await
1645 .unwrap();
1646 assert!(id1 < id2);
1647
1648 let recalled = memory.recall("rust", 5, None).await.unwrap();
1650 assert!(
1651 !recalled.is_empty(),
1652 "recall must return at least one result"
1653 );
1654
1655 let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1657 assert_eq!(history.len(), 2);
1658 assert_eq!(history[0].content, "rust async programming");
1659 }
1660
1661 #[tokio::test]
1662 async fn remember_with_embeddings_supported_but_no_qdrant() {
1663 let memory = test_semantic_memory(true).await;
1664 let cid = memory.sqlite.create_conversation().await.unwrap();
1665
1666 let msg_id = memory.remember(cid, "user", "hello embed").await.unwrap();
1667 assert!(msg_id > MessageId(0));
1668
1669 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1670 assert_eq!(history.len(), 1);
1671 assert_eq!(history[0].content, "hello embed");
1672 }
1673
1674 #[tokio::test]
1675 async fn remember_verifies_content_via_load_history() {
1676 let memory = test_semantic_memory(false).await;
1677 let cid = memory.sqlite.create_conversation().await.unwrap();
1678
1679 memory.remember(cid, "user", "alpha").await.unwrap();
1680 memory.remember(cid, "assistant", "beta").await.unwrap();
1681 memory.remember(cid, "user", "gamma").await.unwrap();
1682
1683 let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1684 assert_eq!(history.len(), 3);
1685 assert_eq!(history[0].content, "alpha");
1686 assert_eq!(history[1].content, "beta");
1687 assert_eq!(history[2].content, "gamma");
1688 }
1689
1690 #[tokio::test]
1691 async fn message_count_multiple_conversations_isolated() {
1692 let memory = test_semantic_memory(false).await;
1693 let cid1 = memory.sqlite().create_conversation().await.unwrap();
1694 let cid2 = memory.sqlite().create_conversation().await.unwrap();
1695 let cid3 = memory.sqlite().create_conversation().await.unwrap();
1696
1697 for _ in 0..5 {
1698 memory.remember(cid1, "user", "msg").await.unwrap();
1699 }
1700 for _ in 0..3 {
1701 memory.remember(cid2, "user", "msg").await.unwrap();
1702 }
1703
1704 assert_eq!(memory.message_count(cid1).await.unwrap(), 5);
1705 assert_eq!(memory.message_count(cid2).await.unwrap(), 3);
1706 assert_eq!(memory.message_count(cid3).await.unwrap(), 0);
1707 }
1708
1709 #[tokio::test]
1710 async fn summarize_empty_messages_range_returns_none() {
1711 let memory = test_semantic_memory(false).await;
1712 let cid = memory.sqlite().create_conversation().await.unwrap();
1713
1714 for i in 0..6 {
1715 memory
1716 .remember(cid, "user", &format!("msg {i}"))
1717 .await
1718 .unwrap();
1719 }
1720
1721 memory.summarize(cid, 3).await.unwrap();
1722 memory.summarize(cid, 3).await.unwrap();
1723
1724 let summaries = memory.load_summaries(cid).await.unwrap();
1725 assert_eq!(summaries.len(), 2);
1726 }
1727
1728 #[tokio::test]
1729 async fn summarize_token_estimate_populated() {
1730 let memory = test_semantic_memory(false).await;
1731 let cid = memory.sqlite().create_conversation().await.unwrap();
1732
1733 for i in 0..5 {
1734 memory
1735 .remember(cid, "user", &format!("message {i}"))
1736 .await
1737 .unwrap();
1738 }
1739
1740 memory.summarize(cid, 3).await.unwrap();
1741 let summaries = memory.load_summaries(cid).await.unwrap();
1742 let token_est = summaries[0].token_estimate;
1743 assert!(token_est > 0);
1744 }
1745
1746 #[tokio::test]
1747 async fn summarize_fails_when_provider_chat_fails() {
1748 let sqlite = SqliteStore::new(":memory:").await.unwrap();
1749 let provider = AnyProvider::Ollama(zeph_llm::ollama::OllamaProvider::new(
1750 "http://127.0.0.1:1",
1751 "test".into(),
1752 "embed".into(),
1753 ));
1754 let memory = SemanticMemory {
1755 sqlite,
1756 qdrant: None,
1757 provider,
1758 embedding_model: "test".into(),
1759 vector_weight: 0.7,
1760 keyword_weight: 0.3,
1761 temporal_decay_enabled: false,
1762 temporal_decay_half_life_days: 30,
1763 mmr_enabled: false,
1764 mmr_lambda: 0.7,
1765 token_counter: Arc::new(TokenCounter::new()),
1766 };
1767 let cid = memory.sqlite().create_conversation().await.unwrap();
1768
1769 for i in 0..5 {
1770 memory
1771 .remember(cid, "user", &format!("msg {i}"))
1772 .await
1773 .unwrap();
1774 }
1775
1776 let result = memory.summarize(cid, 3).await;
1777 assert!(result.is_err());
1778 }
1779
1780 #[tokio::test]
1781 async fn embed_missing_without_embedding_support_returns_zero() {
1782 let memory = test_semantic_memory(false).await;
1783 let cid = memory.sqlite().create_conversation().await.unwrap();
1784 memory
1785 .sqlite()
1786 .save_message(cid, "user", "test message")
1787 .await
1788 .unwrap();
1789
1790 let count = memory.embed_missing().await.unwrap();
1791 assert_eq!(count, 0);
1792 }
1793
1794 #[tokio::test]
1795 async fn has_embedding_returns_false_when_no_qdrant() {
1796 let memory = test_semantic_memory(false).await;
1797 let cid = memory.sqlite.create_conversation().await.unwrap();
1798 let msg_id = memory.remember(cid, "user", "test").await.unwrap();
1799 assert!(!memory.has_embedding(msg_id).await.unwrap());
1800 }
1801
1802 #[tokio::test]
1803 async fn recall_empty_without_qdrant_regardless_of_filter() {
1804 let memory = test_semantic_memory(true).await;
1805 let filter = SearchFilter {
1806 conversation_id: Some(ConversationId(1)),
1807 role: None,
1808 };
1809 let recalled = memory.recall("query", 10, Some(filter)).await.unwrap();
1810 assert!(recalled.is_empty());
1811 }
1812
1813 #[tokio::test]
1814 async fn summarize_message_range_bounds() {
1815 let memory = test_semantic_memory(false).await;
1816 let cid = memory.sqlite().create_conversation().await.unwrap();
1817
1818 for i in 0..8 {
1819 memory
1820 .remember(cid, "user", &format!("msg {i}"))
1821 .await
1822 .unwrap();
1823 }
1824
1825 let summary_id = memory.summarize(cid, 4).await.unwrap().unwrap();
1826 let summaries = memory.load_summaries(cid).await.unwrap();
1827 assert_eq!(summaries.len(), 1);
1828 assert_eq!(summaries[0].id, summary_id);
1829 assert!(summaries[0].first_message_id >= MessageId(1));
1830 assert!(summaries[0].last_message_id >= summaries[0].first_message_id);
1831 }
1832
1833 #[test]
1834 fn build_summarization_prompt_preserves_order() {
1835 let messages = vec![
1836 (MessageId(1), "user".into(), "first".into()),
1837 (MessageId(2), "assistant".into(), "second".into()),
1838 (MessageId(3), "user".into(), "third".into()),
1839 ];
1840 let prompt = build_summarization_prompt(&messages);
1841 let first_pos = prompt.find("user: first").unwrap();
1842 let second_pos = prompt.find("assistant: second").unwrap();
1843 let third_pos = prompt.find("user: third").unwrap();
1844 assert!(first_pos < second_pos);
1845 assert!(second_pos < third_pos);
1846 }
1847
1848 #[test]
1849 fn summary_debug() {
1850 let summary = Summary {
1851 id: 1,
1852 conversation_id: ConversationId(2),
1853 content: "test".into(),
1854 first_message_id: MessageId(1),
1855 last_message_id: MessageId(5),
1856 token_estimate: 10,
1857 };
1858 let dbg = format!("{summary:?}");
1859 assert!(dbg.contains("Summary"));
1860 }
1861
1862 #[tokio::test]
1863 async fn message_count_nonexistent_conversation() {
1864 let memory = test_semantic_memory(false).await;
1865 let count = memory.message_count(ConversationId(999)).await.unwrap();
1866 assert_eq!(count, 0);
1867 }
1868
1869 #[tokio::test]
1870 async fn load_summaries_nonexistent_conversation() {
1871 let memory = test_semantic_memory(false).await;
1872 let summaries = memory.load_summaries(ConversationId(999)).await.unwrap();
1873 assert!(summaries.is_empty());
1874 }
1875
1876 #[tokio::test]
1877 async fn store_session_summary_no_qdrant_noop() {
1878 let memory = test_semantic_memory(true).await;
1879 let result = memory
1880 .store_session_summary(ConversationId(1), "test summary")
1881 .await;
1882 assert!(result.is_ok());
1883 }
1884
1885 #[tokio::test]
1886 async fn store_session_summary_no_embeddings_noop() {
1887 let memory = test_semantic_memory(false).await;
1888 let result = memory
1889 .store_session_summary(ConversationId(1), "test summary")
1890 .await;
1891 assert!(result.is_ok());
1892 }
1893
1894 #[tokio::test]
1895 async fn search_session_summaries_no_qdrant_empty() {
1896 let memory = test_semantic_memory(true).await;
1897 let results = memory
1898 .search_session_summaries("query", 5, None)
1899 .await
1900 .unwrap();
1901 assert!(results.is_empty());
1902 }
1903
1904 #[tokio::test]
1905 async fn search_session_summaries_no_embeddings_empty() {
1906 let memory = test_semantic_memory(false).await;
1907 let results = memory
1908 .search_session_summaries("query", 5, Some(ConversationId(1)))
1909 .await
1910 .unwrap();
1911 assert!(results.is_empty());
1912 }
1913
1914 #[test]
1915 fn session_summary_result_debug() {
1916 let result = SessionSummaryResult {
1917 summary_text: "test".into(),
1918 score: 0.9,
1919 conversation_id: ConversationId(1),
1920 };
1921 let dbg = format!("{result:?}");
1922 assert!(dbg.contains("SessionSummaryResult"));
1923 }
1924
1925 #[test]
1926 fn session_summary_result_clone() {
1927 let result = SessionSummaryResult {
1928 summary_text: "test".into(),
1929 score: 0.9,
1930 conversation_id: ConversationId(1),
1931 };
1932 let cloned = result.clone();
1933 assert_eq!(result.summary_text, cloned.summary_text);
1934 assert_eq!(result.conversation_id, cloned.conversation_id);
1935 }
1936
1937 #[tokio::test]
1938 async fn recall_fts5_fallback_without_qdrant() {
1939 let memory = test_semantic_memory(false).await;
1940 let cid = memory.sqlite.create_conversation().await.unwrap();
1941
1942 memory
1943 .remember(cid, "user", "rust programming guide")
1944 .await
1945 .unwrap();
1946 memory
1947 .remember(cid, "assistant", "python tutorial")
1948 .await
1949 .unwrap();
1950 memory
1951 .remember(cid, "user", "advanced rust patterns")
1952 .await
1953 .unwrap();
1954
1955 let recalled = memory.recall("rust", 5, None).await.unwrap();
1956 assert_eq!(recalled.len(), 2);
1957 assert!(recalled[0].score >= recalled[1].score);
1958 }
1959
1960 #[tokio::test]
1961 async fn recall_fts5_fallback_with_filter() {
1962 let memory = test_semantic_memory(false).await;
1963 let cid1 = memory.sqlite.create_conversation().await.unwrap();
1964 let cid2 = memory.sqlite.create_conversation().await.unwrap();
1965
1966 memory.remember(cid1, "user", "hello world").await.unwrap();
1967 memory
1968 .remember(cid2, "user", "hello universe")
1969 .await
1970 .unwrap();
1971
1972 let filter = SearchFilter {
1973 conversation_id: Some(cid1),
1974 role: None,
1975 };
1976 let recalled = memory.recall("hello", 5, Some(filter)).await.unwrap();
1977 assert_eq!(recalled.len(), 1);
1978 }
1979
1980 #[tokio::test]
1981 async fn recall_fts5_no_matches_returns_empty() {
1982 let memory = test_semantic_memory(false).await;
1983 let cid = memory.sqlite.create_conversation().await.unwrap();
1984
1985 memory.remember(cid, "user", "hello world").await.unwrap();
1986
1987 let recalled = memory.recall("nonexistent", 5, None).await.unwrap();
1988 assert!(recalled.is_empty());
1989 }
1990
1991 #[tokio::test]
1992 async fn recall_fts5_respects_limit() {
1993 let memory = test_semantic_memory(false).await;
1994 let cid = memory.sqlite.create_conversation().await.unwrap();
1995
1996 for i in 0..10 {
1997 memory
1998 .remember(cid, "user", &format!("test message number {i}"))
1999 .await
2000 .unwrap();
2001 }
2002
2003 let recalled = memory.recall("test", 3, None).await.unwrap();
2004 assert_eq!(recalled.len(), 3);
2005 }
2006
2007 #[tokio::test]
2010 async fn summarize_fallback_to_plain_text_when_structured_fails() {
2011 let sqlite = SqliteStore::new(":memory:").await.unwrap();
2019 let mut mock = MockProvider::default();
2020 mock.default_response = "plain text summary".into();
2022 let provider = AnyProvider::Mock(mock);
2023
2024 let memory = SemanticMemory {
2025 sqlite,
2026 qdrant: None,
2027 provider,
2028 embedding_model: "test".into(),
2029 vector_weight: 0.7,
2030 keyword_weight: 0.3,
2031 temporal_decay_enabled: false,
2032 temporal_decay_half_life_days: 30,
2033 mmr_enabled: false,
2034 mmr_lambda: 0.7,
2035 token_counter: Arc::new(TokenCounter::new()),
2036 };
2037
2038 let cid = memory.sqlite().create_conversation().await.unwrap();
2039 for i in 0..5 {
2040 memory
2041 .remember(cid, "user", &format!("msg {i}"))
2042 .await
2043 .unwrap();
2044 }
2045
2046 let result = memory.summarize(cid, 3).await;
2047 assert!(result.is_ok());
2053 let summaries = memory.load_summaries(cid).await.unwrap();
2054 assert_eq!(summaries.len(), 1);
2055 assert!(!summaries[0].content.is_empty());
2056 }
2057
2058 #[test]
2061 fn temporal_decay_disabled_leaves_scores_unchanged() {
2062 let mut ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2063 let timestamps = std::collections::HashMap::new();
2064 apply_temporal_decay(&mut ranked, ×tamps, 30);
2065 assert!((ranked[0].1 - 1.0).abs() < f64::EPSILON);
2066 assert!((ranked[1].1 - 0.5).abs() < f64::EPSILON);
2067 }
2068
2069 #[test]
2070 fn temporal_decay_zero_age_preserves_score() {
2071 let now = std::time::SystemTime::now()
2072 .duration_since(std::time::UNIX_EPOCH)
2073 .unwrap_or_default()
2074 .as_secs()
2075 .cast_signed();
2076 let mut ranked = vec![(MessageId(1), 1.0f64)];
2077 let mut timestamps = std::collections::HashMap::new();
2078 timestamps.insert(MessageId(1), now);
2079 apply_temporal_decay(&mut ranked, ×tamps, 30);
2080 assert!((ranked[0].1 - 1.0).abs() < 0.01);
2082 }
2083
2084 #[test]
2085 fn temporal_decay_half_life_halves_score() {
2086 let half_life = 30u32;
2088 let age_secs = i64::from(half_life) * 86400;
2089 let now = std::time::SystemTime::now()
2090 .duration_since(std::time::UNIX_EPOCH)
2091 .unwrap_or_default()
2092 .as_secs()
2093 .cast_signed();
2094 let ts = now - age_secs;
2095 let mut ranked = vec![(MessageId(1), 1.0f64)];
2096 let mut timestamps = std::collections::HashMap::new();
2097 timestamps.insert(MessageId(1), ts);
2098 apply_temporal_decay(&mut ranked, ×tamps, half_life);
2099 assert!(
2101 (ranked[0].1 - 0.5).abs() < 0.01,
2102 "score was {}",
2103 ranked[0].1
2104 );
2105 }
2106
2107 #[test]
2110 fn mmr_empty_input_returns_empty() {
2111 let ranked = vec![];
2112 let vectors = std::collections::HashMap::new();
2113 let result = apply_mmr(&ranked, &vectors, 0.7, 5);
2114 assert!(result.is_empty());
2115 }
2116
2117 #[test]
2118 fn mmr_returns_up_to_limit() {
2119 let ranked = vec![
2120 (MessageId(1), 1.0f64),
2121 (MessageId(2), 0.9f64),
2122 (MessageId(3), 0.8f64),
2123 ];
2124 let mut vectors = std::collections::HashMap::new();
2125 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2126 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2127 vectors.insert(MessageId(3), vec![1.0f32, 0.0]);
2128 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2129 assert_eq!(result.len(), 2);
2130 }
2131
2132 #[test]
2133 fn mmr_without_vectors_picks_by_relevance() {
2134 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2135 let vectors = std::collections::HashMap::new();
2136 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2137 assert_eq!(result.len(), 2);
2138 assert_eq!(result[0].0, MessageId(1));
2139 }
2140
2141 #[test]
2142 fn mmr_prefers_diverse_over_redundant() {
2143 let ranked = vec![
2145 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.9f64), ];
2149 let mut vectors = std::collections::HashMap::new();
2150 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2151 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 2);
2154 assert_eq!(result.len(), 2);
2155 assert_eq!(result[0].0, MessageId(1));
2156 assert_eq!(result[1].0, MessageId(2));
2158 }
2159
2160 #[test]
2161 fn temporal_decay_half_life_zero_is_noop() {
2162 let now = std::time::SystemTime::now()
2163 .duration_since(std::time::UNIX_EPOCH)
2164 .unwrap_or_default()
2165 .as_secs()
2166 .cast_signed();
2167 let age_secs = 30i64 * 86400;
2168 let ts = now - age_secs;
2169 let mut ranked = vec![(MessageId(1), 1.0f64)];
2170 let mut timestamps = std::collections::HashMap::new();
2171 timestamps.insert(MessageId(1), ts);
2172 apply_temporal_decay(&mut ranked, ×tamps, 0);
2174 assert!(
2175 (ranked[0].1 - 1.0).abs() < f64::EPSILON,
2176 "score was {}",
2177 ranked[0].1
2178 );
2179 }
2180
2181 #[test]
2182 fn temporal_decay_huge_age_near_zero() {
2183 let now = std::time::SystemTime::now()
2184 .duration_since(std::time::UNIX_EPOCH)
2185 .unwrap_or_default()
2186 .as_secs()
2187 .cast_signed();
2188 let age_secs = 3650i64 * 86400;
2190 let ts = now - age_secs;
2191 let mut ranked = vec![(MessageId(1), 1.0f64)];
2192 let mut timestamps = std::collections::HashMap::new();
2193 timestamps.insert(MessageId(1), ts);
2194 apply_temporal_decay(&mut ranked, ×tamps, 30);
2195 assert!(ranked[0].1 < 0.001, "score was {}", ranked[0].1);
2197 }
2198
2199 #[test]
2200 fn temporal_decay_small_half_life() {
2201 let now = std::time::SystemTime::now()
2203 .duration_since(std::time::UNIX_EPOCH)
2204 .unwrap_or_default()
2205 .as_secs()
2206 .cast_signed();
2207 let ts = now - 7 * 86400i64;
2208 let mut ranked = vec![(MessageId(1), 1.0f64)];
2209 let mut timestamps = std::collections::HashMap::new();
2210 timestamps.insert(MessageId(1), ts);
2211 apply_temporal_decay(&mut ranked, ×tamps, 1);
2212 assert!(ranked[0].1 < 0.01, "score was {}", ranked[0].1);
2213 }
2214
2215 #[test]
2216 fn mmr_lambda_zero_max_diversity() {
2217 let ranked = vec![
2219 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.85f64), ];
2223 let mut vectors = std::collections::HashMap::new();
2224 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2225 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.0, 3);
2228 assert_eq!(result.len(), 3);
2229 assert_eq!(result[1].0, MessageId(2));
2231 }
2232
2233 #[test]
2234 fn mmr_lambda_one_pure_relevance() {
2235 let ranked = vec![
2237 (MessageId(1), 1.0f64),
2238 (MessageId(2), 0.8f64),
2239 (MessageId(3), 0.6f64),
2240 ];
2241 let mut vectors = std::collections::HashMap::new();
2242 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2243 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2244 vectors.insert(MessageId(3), vec![0.5f32, 0.5]);
2245 let result = apply_mmr(&ranked, &vectors, 1.0, 3);
2246 assert_eq!(result.len(), 3);
2247 assert_eq!(result[0].0, MessageId(1));
2248 assert_eq!(result[1].0, MessageId(2));
2249 assert_eq!(result[2].0, MessageId(3));
2250 }
2251
2252 #[test]
2253 fn mmr_limit_zero_returns_empty() {
2254 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.8f64)];
2255 let mut vectors = std::collections::HashMap::new();
2256 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2257 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2258 let result = apply_mmr(&ranked, &vectors, 0.7, 0);
2259 assert!(result.is_empty());
2260 }
2261
2262 #[test]
2263 fn mmr_duplicate_vectors_penalizes_second() {
2264 let ranked = vec![
2266 (MessageId(1), 1.0f64),
2267 (MessageId(2), 1.0f64), (MessageId(3), 0.9f64), ];
2270 let mut vectors = std::collections::HashMap::new();
2271 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2272 vectors.insert(MessageId(2), vec![1.0f32, 0.0]); vectors.insert(MessageId(3), vec![0.0f32, 1.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 3);
2275 assert_eq!(result.len(), 3);
2276 assert_eq!(result[0].0, MessageId(1));
2277 assert_eq!(result[1].0, MessageId(3));
2279 }
2280
2281 use proptest::prelude::*;
2284
2285 proptest! {
2286 #[test]
2287 fn count_tokens_never_panics(s in ".*") {
2288 let counter = crate::token_counter::TokenCounter::new();
2289 let _ = counter.count_tokens(&s);
2290 }
2291 }
2292}