1use zeph_llm::any::AnyProvider;
5use zeph_llm::provider::{LlmProvider, Message, MessageMetadata, Role};
6
7use std::sync::Arc;
8
9use crate::embedding_store::{EmbeddingStore, MessageKind, SearchFilter};
10use crate::error::MemoryError;
11use crate::sqlite::SqliteStore;
12use crate::token_counter::TokenCounter;
13use crate::types::{ConversationId, MessageId};
14use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
15
16const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
17const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
18
/// LLM-produced summary payload: prose summary plus extracted facts/entities.
///
/// Derives `JsonSchema` so the provider can be asked for typed (structured)
/// JSON output matching this shape.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    /// Prose summary of the conversation window.
    pub summary: String,
    /// Discrete facts worth indexing individually (see `store_key_facts`).
    pub key_facts: Vec<String>,
    /// Named entities mentioned in the conversation.
    pub entities: Vec<String>,
}
25
/// A message returned by `recall`, paired with its combined relevance score.
#[derive(Debug)]
pub struct RecalledMessage {
    pub message: Message,
    /// Hybrid (vector + keyword, optionally decayed/MMR-adjusted) score;
    /// higher means more relevant.
    pub score: f32,
}
31
/// A persisted conversation summary row loaded from SQLite.
#[derive(Debug, Clone)]
pub struct Summary {
    /// SQLite rowid of the summary.
    pub id: i64,
    pub conversation_id: ConversationId,
    /// Summary text (prose).
    pub content: String,
    /// First message covered by this summary.
    pub first_message_id: MessageId,
    /// Last message covered by this summary (the summarization watermark).
    pub last_message_id: MessageId,
    /// Approximate token count of `content`.
    pub token_estimate: i64,
}
41
/// A session-summary hit from the cross-conversation summaries collection.
#[derive(Debug, Clone)]
pub struct SessionSummaryResult {
    pub summary_text: String,
    /// Vector-similarity score reported by the vector store.
    pub score: f32,
    /// Conversation the summary originally came from.
    pub conversation_id: ConversationId,
}
48
/// Cosine similarity between two equal-length vectors.
///
/// Returns 0.0 for mismatched lengths, empty input, or when either vector has
/// zero magnitude, so callers never divide by zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Single pass accumulating dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
61
62fn apply_temporal_decay(
63 ranked: &mut [(MessageId, f64)],
64 timestamps: &std::collections::HashMap<MessageId, i64>,
65 half_life_days: u32,
66) {
67 if half_life_days == 0 {
68 return;
69 }
70 let now = std::time::SystemTime::now()
71 .duration_since(std::time::UNIX_EPOCH)
72 .unwrap_or_default()
73 .as_secs()
74 .cast_signed();
75 let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
76
77 for (msg_id, score) in ranked.iter_mut() {
78 if let Some(&ts) = timestamps.get(msg_id) {
79 #[allow(clippy::cast_precision_loss)]
80 let age_days = (now - ts).max(0) as f64 / 86400.0;
81 *score *= (-lambda * age_days).exp();
82 }
83 }
84}
85
86fn apply_mmr(
87 ranked: &[(MessageId, f64)],
88 vectors: &std::collections::HashMap<MessageId, Vec<f32>>,
89 lambda: f32,
90 limit: usize,
91) -> Vec<(MessageId, f64)> {
92 if ranked.is_empty() || limit == 0 {
93 return Vec::new();
94 }
95
96 let lambda = f64::from(lambda);
97 let mut selected: Vec<(MessageId, f64)> = Vec::with_capacity(limit);
98 let mut remaining: Vec<(MessageId, f64)> = ranked.to_vec();
99
100 while selected.len() < limit && !remaining.is_empty() {
101 let best_idx = if selected.is_empty() {
102 0
104 } else {
105 let mut best = 0usize;
106 let mut best_score = f64::NEG_INFINITY;
107
108 for (i, &(cand_id, relevance)) in remaining.iter().enumerate() {
109 let max_sim = if let Some(cand_vec) = vectors.get(&cand_id) {
110 selected
111 .iter()
112 .filter_map(|(sel_id, _)| vectors.get(sel_id))
113 .map(|sel_vec| f64::from(cosine_similarity(cand_vec, sel_vec)))
114 .fold(f64::NEG_INFINITY, f64::max)
115 } else {
116 0.0
117 };
118 let max_sim = if max_sim == f64::NEG_INFINITY {
119 0.0
120 } else {
121 max_sim
122 };
123 let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim;
124 if mmr_score > best_score {
125 best_score = mmr_score;
126 best = i;
127 }
128 }
129 best
130 };
131
132 selected.push(remaining.remove(best_idx));
133 }
134
135 selected
136}
137
138fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
139 let mut prompt = String::from(
140 "Summarize the following conversation. Extract key facts, decisions, entities, \
141 and context needed to continue the conversation.\n\n\
142 Respond in JSON with fields: summary (string), key_facts (list of strings), \
143 entities (list of strings).\n\nConversation:\n",
144 );
145
146 for (_, role, content) in messages {
147 prompt.push_str(role);
148 prompt.push_str(": ");
149 prompt.push_str(content);
150 prompt.push('\n');
151 }
152
153 prompt
154}
155
/// Hybrid conversational memory: SQLite for durable storage and keyword
/// (FTS5) search, plus an optional vector store for semantic recall.
pub struct SemanticMemory {
    sqlite: SqliteStore,
    /// `None` when the vector store was unreachable at construction time;
    /// semantic features then degrade to no-ops instead of erroring.
    qdrant: Option<EmbeddingStore>,
    provider: AnyProvider,
    /// Model name recorded alongside each stored embedding.
    embedding_model: String,
    /// Relative weight of vector-similarity scores in hybrid ranking.
    vector_weight: f64,
    /// Relative weight of FTS5 keyword scores in hybrid ranking.
    keyword_weight: f64,
    /// When true, recall scores decay exponentially with message age.
    temporal_decay_enabled: bool,
    temporal_decay_half_life_days: u32,
    /// When true, recall results are re-ranked for diversity (MMR).
    mmr_enabled: bool,
    mmr_lambda: f32,
    pub token_counter: Arc<TokenCounter>,
}
169
170impl SemanticMemory {
    /// Convenience constructor using the default hybrid-ranking weights
    /// (vector 0.7 / keyword 0.3).
    ///
    /// # Errors
    /// Returns an error if the SQLite store cannot be opened; Qdrant
    /// unavailability is tolerated (see `with_weights`).
    pub async fn new(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
    ) -> Result<Self, MemoryError> {
        Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
    }
186
187 pub async fn with_weights(
193 sqlite_path: &str,
194 qdrant_url: &str,
195 provider: AnyProvider,
196 embedding_model: &str,
197 vector_weight: f64,
198 keyword_weight: f64,
199 ) -> Result<Self, MemoryError> {
200 let sqlite = SqliteStore::new(sqlite_path).await?;
201 let pool = sqlite.pool().clone();
202
203 let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
204 Ok(store) => Some(store),
205 Err(e) => {
206 tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
207 None
208 }
209 };
210
211 Ok(Self {
212 sqlite,
213 qdrant,
214 provider,
215 embedding_model: embedding_model.into(),
216 vector_weight,
217 keyword_weight,
218 temporal_decay_enabled: false,
219 temporal_decay_half_life_days: 30,
220 mmr_enabled: false,
221 mmr_lambda: 0.7,
222 token_counter: Arc::new(TokenCounter::new()),
223 })
224 }
225
226 #[must_use]
228 pub fn with_ranking_options(
229 mut self,
230 temporal_decay_enabled: bool,
231 temporal_decay_half_life_days: u32,
232 mmr_enabled: bool,
233 mmr_lambda: f32,
234 ) -> Self {
235 self.temporal_decay_enabled = temporal_decay_enabled;
236 self.temporal_decay_half_life_days = temporal_decay_half_life_days;
237 self.mmr_enabled = mmr_enabled;
238 self.mmr_lambda = mmr_lambda;
239 self
240 }
241
242 pub async fn with_sqlite_backend(
248 sqlite_path: &str,
249 provider: AnyProvider,
250 embedding_model: &str,
251 vector_weight: f64,
252 keyword_weight: f64,
253 ) -> Result<Self, MemoryError> {
254 let sqlite = SqliteStore::new(sqlite_path).await?;
255 let pool = sqlite.pool().clone();
256 let store = EmbeddingStore::new_sqlite(pool);
257
258 Ok(Self {
259 sqlite,
260 qdrant: Some(store),
261 provider,
262 embedding_model: embedding_model.into(),
263 vector_weight,
264 keyword_weight,
265 temporal_decay_enabled: false,
266 temporal_decay_half_life_days: 30,
267 mmr_enabled: false,
268 mmr_lambda: 0.7,
269 token_counter: Arc::new(TokenCounter::new()),
270 })
271 }
272
    /// Persists a message to SQLite and, when possible, indexes its embedding.
    ///
    /// The embedding step is best-effort: if the vector store is missing, the
    /// provider lacks embedding support, or any embed/store call fails, the
    /// failure is logged and the message is still considered remembered.
    ///
    /// # Errors
    /// Returns an error only if the SQLite insert fails.
    pub async fn remember(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
    ) -> Result<MessageId, MemoryError> {
        let message_id = self
            .sqlite
            .save_message(conversation_id, role, content)
            .await?;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    // Collection is sized from the actual embedding length;
                    // 896 is a fallback for lengths that don't fit in u64
                    // (practically unreachable).
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok(message_id)
    }
323
    /// Like [`remember`](Self::remember), but also persists structured message
    /// parts (`parts_json`) and reports whether an embedding was stored.
    ///
    /// The returned bool is `true` only when the embedding pipeline succeeded
    /// end-to-end; all embedding failures are logged and non-fatal.
    ///
    /// # Errors
    /// Returns an error only if the SQLite insert fails.
    pub async fn remember_with_parts(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<(MessageId, bool), MemoryError> {
        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        let mut embedding_stored = false;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    // Same sizing fallback as in `remember`.
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    } else {
                        embedding_stored = true;
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok((message_id, embedding_stored))
    }
378
    /// Persists a message (with parts) to SQLite without attempting any
    /// embedding — for callers that handle indexing separately.
    ///
    /// # Errors
    /// Returns an error if the SQLite insert fails.
    pub async fn save_only(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<MessageId, MemoryError> {
        self.sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await
    }
397
    /// Hybrid retrieval: merges FTS5 keyword hits with vector-similarity hits.
    ///
    /// Scores from each source are max-normalized, weighted by
    /// `keyword_weight` / `vector_weight`, and summed per message. Optional
    /// post-processing then applies exponential temporal decay and/or MMR
    /// re-ranking before the top `limit` messages are hydrated from SQLite.
    ///
    /// Keyword-search failures are logged and treated as "no keyword hits";
    /// embedding/vector-store failures during the initial search propagate.
    ///
    /// # Errors
    /// Returns an error if embedding the query, the vector search, or the
    /// final SQLite hydration fails.
    #[allow(clippy::too_many_lines)]
    pub async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        // Over-fetch (2x limit) from each source so the merged ranking still
        // has enough candidates after decay/MMR pruning.
        let keyword_results = match self
            .sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
        {
            Ok(results) => results,
            Err(e) => {
                tracing::warn!("FTS5 keyword search failed: {e:#}");
                Vec::new()
            }
        };

        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            let query_vector = self.provider.embed(query).await?;
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant.search(&query_vector, limit * 2, filter).await?
        } else {
            Vec::new()
        };

        // message id -> combined weighted score.
        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            // Normalize by the max score; a non-positive max falls back to 1
            // to avoid sign flips or division by zero.
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            return Ok(Vec::new());
        }

        // Sort descending by combined score; NaN-safe via Ordering::Equal.
        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    // Decay changes relative order; re-sort before truncation.
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                }
                Err(e) => {
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // MMR needs the stored vectors; on any failure fall back to a plain
        // top-`limit` truncation of the relevance ranking.
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                    }
                    Ok(_) => {
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Hydrate full messages; ids missing from SQLite are silently dropped.
        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        Ok(recalled)
    }
536
537 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
543 match &self.qdrant {
544 Some(qdrant) => qdrant.has_embedding(message_id).await,
545 None => Ok(false),
546 }
547 }
548
    /// Backfills embeddings for messages that were saved without one.
    ///
    /// Processes at most 1000 messages per call. Individual embed/store
    /// failures are logged and skipped; the returned count is the number of
    /// messages successfully embedded (0 when no vector store or embedding
    /// support is available).
    ///
    /// # Errors
    /// Returns an error if listing unembedded messages, probing the embedding
    /// dimension, or ensuring the collection fails.
    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(0);
        };
        if !self.provider.supports_embeddings() {
            return Ok(0);
        }

        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;

        if unembedded.is_empty() {
            return Ok(0);
        }

        // Embed a throwaway probe string just to learn the model's vector
        // dimension before creating/validating the collection.
        let probe = self.provider.embed("probe").await?;
        let vector_size = u64::try_from(probe.len())?;
        qdrant.ensure_collection(vector_size).await?;

        let mut count = 0;
        for (msg_id, conversation_id, role, content) in &unembedded {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    if let Err(e) = qdrant
                        .store(
                            *msg_id,
                            *conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
                        continue;
                    }
                    count += 1;
                }
                Err(e) => {
                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
                }
            }
        }

        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
        Ok(count)
    }
604
    /// Embeds and stores a whole-session summary in a dedicated collection
    /// used for cross-conversation recall.
    ///
    /// Silently a no-op (`Ok(())`) when there is no vector store or the
    /// provider cannot produce embeddings.
    ///
    /// # Errors
    /// Returns an error if embedding or any vector-store call fails.
    pub async fn store_session_summary(
        &self,
        conversation_id: ConversationId,
        summary_text: &str,
    ) -> Result<(), MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(());
        };
        if !self.provider.supports_embeddings() {
            return Ok(());
        }

        let vector = self.provider.embed(summary_text).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        // Payload keys are read back verbatim in `search_session_summaries`.
        let payload = serde_json::json!({
            "conversation_id": conversation_id.0,
            "summary_text": summary_text,
        });

        qdrant
            .store_to_collection(SESSION_SUMMARIES_COLLECTION, payload, vector)
            .await?;

        tracing::debug!(
            conversation_id = conversation_id.0,
            "stored session summary"
        );
        Ok(())
    }
643
    /// Semantic search over stored session summaries.
    ///
    /// `exclude_conversation_id` filters out the given conversation so a
    /// session does not recall its own summary. Returns empty when no vector
    /// store or embedding support is available; points with malformed
    /// payloads are silently dropped.
    ///
    /// # Errors
    /// Returns an error if embedding the query or any vector-store call fails.
    pub async fn search_session_summaries(
        &self,
        query: &str,
        limit: usize,
        exclude_conversation_id: Option<ConversationId>,
    ) -> Result<Vec<SessionSummaryResult>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        // must_not filter: exclude the caller's own conversation id.
        let filter = exclude_conversation_id.map(|cid| VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "conversation_id".into(),
                value: FieldValue::Integer(cid.0),
            }],
        });

        let points = qdrant
            .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter)
            .await?;

        let results = points
            .into_iter()
            .filter_map(|point| {
                let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned();
                let conversation_id =
                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
                Some(SessionSummaryResult {
                    summary_text,
                    score: point.score,
                    conversation_id,
                })
            })
            .collect();

        Ok(results)
    }
696
    /// Direct access to the underlying SQLite store.
    #[must_use]
    pub fn sqlite(&self) -> &SqliteStore {
        &self.sqlite
    }
702
703 pub async fn is_vector_store_connected(&self) -> bool {
708 match self.qdrant.as_ref() {
709 Some(store) => store.health_check().await,
710 None => false,
711 }
712 }
713
    /// Whether a vector store was configured at construction time. Says
    /// nothing about current reachability — see `is_vector_store_connected`.
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }
719
    /// Total number of messages stored for `conversation_id`.
    ///
    /// # Errors
    /// Returns an error if the SQLite count query fails.
    pub async fn message_count(&self, conversation_id: ConversationId) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }
728
729 pub async fn unsummarized_message_count(
735 &self,
736 conversation_id: ConversationId,
737 ) -> Result<i64, MemoryError> {
738 let after_id = self
739 .sqlite
740 .latest_summary_last_message_id(conversation_id)
741 .await?
742 .unwrap_or(MessageId(0));
743 self.sqlite
744 .count_messages_after(conversation_id, after_id)
745 .await
746 }
747
748 pub async fn load_summaries(
754 &self,
755 conversation_id: ConversationId,
756 ) -> Result<Vec<Summary>, MemoryError> {
757 let rows = self.sqlite.load_summaries(conversation_id).await?;
758 let summaries = rows
759 .into_iter()
760 .map(
761 |(
762 id,
763 conversation_id,
764 content,
765 first_message_id,
766 last_message_id,
767 token_estimate,
768 )| {
769 Summary {
770 id,
771 conversation_id,
772 content,
773 first_message_id,
774 last_message_id,
775 token_estimate,
776 }
777 },
778 )
779 .collect();
780 Ok(summaries)
781 }
782
    /// Summarizes the oldest `message_count` messages not yet covered by a
    /// previous summary.
    ///
    /// Returns `Ok(None)` when the conversation has no more than
    /// `message_count` messages in total, or nothing new since the last
    /// summary. Prefers a structured (JSON) summary from the provider and
    /// falls back to plain text on failure. The summary row is persisted to
    /// SQLite, best-effort embedded into the vector store, and any extracted
    /// key facts are indexed separately.
    ///
    /// # Errors
    /// Returns an error on SQLite failures, token-count conversion overflow,
    /// or when both structured and plain-text summarization fail.
    pub async fn summarize(
        &self,
        conversation_id: ConversationId,
        message_count: usize,
    ) -> Result<Option<i64>, MemoryError> {
        let total = self.sqlite.count_messages(conversation_id).await?;

        // Only summarize once the conversation exceeds the window size.
        if total <= i64::try_from(message_count)? {
            return Ok(None);
        }

        // Watermark: last message covered by the newest existing summary.
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));

        let messages = self
            .sqlite
            .load_messages_range(conversation_id, after_id, message_count)
            .await?;

        if messages.is_empty() {
            return Ok(None);
        }

        let prompt = build_summarization_prompt(&messages);
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        // Try typed/structured output first; degrade to a plain-text chat
        // completion wrapped in an empty StructuredSummary on failure.
        let structured = match self
            .provider
            .chat_typed_erased::<StructuredSummary>(&chat_messages)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!(
                    "structured summarization failed, falling back to plain text: {e:#}"
                );
                let plain = self.provider.chat(&chat_messages).await?;
                StructuredSummary {
                    summary: plain,
                    key_facts: vec![],
                    entities: vec![],
                }
            }
        };
        let summary_text = &structured.summary;

        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
        let first_message_id = messages[0].0;
        let last_message_id = messages[messages.len() - 1].0;

        let summary_id = self
            .sqlite
            .save_summary(
                conversation_id,
                summary_text,
                first_message_id,
                last_message_id,
                token_estimate,
            )
            .await?;

        // Best-effort: embed the summary itself so it participates in recall.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(summary_text).await {
                Ok(vector) => {
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            // Summaries are keyed by their summary row id.
                            MessageId(summary_id),
                            conversation_id,
                            "system",
                            vector,
                            MessageKind::Summary,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to embed summary: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate summary embedding: {e:#}");
                }
            }
        }

        if !structured.key_facts.is_empty() {
            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
                .await;
        }

        Ok(Some(summary_id))
    }
896
897 async fn store_key_facts(
898 &self,
899 conversation_id: ConversationId,
900 source_summary_id: i64,
901 key_facts: &[String],
902 ) {
903 let Some(qdrant) = &self.qdrant else {
904 return;
905 };
906 if !self.provider.supports_embeddings() {
907 return;
908 }
909
910 let Some(first_fact) = key_facts.first() else {
911 return;
912 };
913 let first_vector = match self.provider.embed(first_fact).await {
914 Ok(v) => v,
915 Err(e) => {
916 tracing::warn!("Failed to embed key fact: {e:#}");
917 return;
918 }
919 };
920 let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
921 if let Err(e) = qdrant
922 .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
923 .await
924 {
925 tracing::warn!("Failed to ensure key_facts collection: {e:#}");
926 return;
927 }
928
929 let first_payload = serde_json::json!({
930 "conversation_id": conversation_id.0,
931 "fact_text": first_fact,
932 "source_summary_id": source_summary_id,
933 });
934 if let Err(e) = qdrant
935 .store_to_collection(KEY_FACTS_COLLECTION, first_payload, first_vector)
936 .await
937 {
938 tracing::warn!("Failed to store key fact: {e:#}");
939 }
940
941 for fact in &key_facts[1..] {
942 match self.provider.embed(fact).await {
943 Ok(vector) => {
944 let payload = serde_json::json!({
945 "conversation_id": conversation_id.0,
946 "fact_text": fact,
947 "source_summary_id": source_summary_id,
948 });
949 if let Err(e) = qdrant
950 .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
951 .await
952 {
953 tracing::warn!("Failed to store key fact: {e:#}");
954 }
955 }
956 Err(e) => {
957 tracing::warn!("Failed to embed key fact: {e:#}");
958 }
959 }
960 }
961 }
962
963 pub async fn search_key_facts(
969 &self,
970 query: &str,
971 limit: usize,
972 ) -> Result<Vec<String>, MemoryError> {
973 let Some(qdrant) = &self.qdrant else {
974 return Ok(Vec::new());
975 };
976 if !self.provider.supports_embeddings() {
977 return Ok(Vec::new());
978 }
979
980 let vector = self.provider.embed(query).await?;
981 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
982 qdrant
983 .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
984 .await?;
985
986 let points = qdrant
987 .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
988 .await?;
989
990 let facts = points
991 .into_iter()
992 .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
993 .collect();
994
995 Ok(facts)
996 }
997}
998
999#[cfg(test)]
1000mod tests {
1001 use zeph_llm::mock::MockProvider;
1002 use zeph_llm::provider::Role;
1003
1004 use super::*;
1005
    /// Mock LLM provider shared by all tests (no network calls).
    fn test_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    /// Builds an in-memory `SemanticMemory` with no vector store attached.
    ///
    /// NOTE(review): `_supports_embeddings` is currently unused — the mock
    /// provider decides embedding support itself; confirm whether this flag
    /// should configure the mock.
    async fn test_semantic_memory(_supports_embeddings: bool) -> SemanticMemory {
        let provider = test_provider();
        let sqlite = SqliteStore::new(":memory:").await.unwrap();

        SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        }
    }
1028
    // Message persistence round-trips through SQLite even without a vector store.
    #[tokio::test]
    async fn remember_saves_to_sqlite() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite.create_conversation().await.unwrap();
        let msg_id = memory.remember(cid, "user", "hello").await.unwrap();

        assert_eq!(msg_id, MessageId(1));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
    }

    // Structured parts JSON is stored alongside the plain-text content.
    #[tokio::test]
    async fn remember_with_parts_saves_parts_json() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let parts_json =
            r#"[{"kind":"ToolOutput","tool_name":"shell","body":"hello","compacted_at":null}]"#;
        let (msg_id, _embedding_stored) = memory
            .remember_with_parts(cid, "assistant", "tool output", parts_json)
            .await
            .unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "tool output");
    }

    // Without a vector store, recall degrades to an empty result.
    #[tokio::test]
    async fn recall_returns_empty_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }

    // has_embedding is always false when no vector store is attached.
    #[tokio::test]
    async fn has_embedding_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let has_embedding = memory.has_embedding(MessageId(1)).await.unwrap();
        assert!(!has_embedding);
    }

    // embed_missing is a no-op without a vector store.
    #[tokio::test]
    async fn embed_missing_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    // The sqlite() accessor exposes the underlying store for direct use.
    #[tokio::test]
    async fn sqlite_accessor() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        assert_eq!(cid, ConversationId(1));

        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
    }

    #[tokio::test]
    async fn has_vector_store_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.has_vector_store());
    }

    #[tokio::test]
    async fn is_vector_store_connected_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.is_vector_store_connected().await);
    }

    // Recall also degrades gracefully when the provider lacks embeddings.
    #[tokio::test]
    async fn recall_returns_empty_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }
1122
    // embed_missing skips all work when the provider cannot embed.
    #[tokio::test]
    async fn embed_missing_returns_zero_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn message_count_empty_conversation() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn message_count_after_saves() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "msg1").await.unwrap();
        memory.remember(cid, "assistant", "msg2").await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 2);
    }

    // Summarizing advances the watermark, shrinking the unsummarized count
    // while leaving the raw message count untouched.
    #[tokio::test]
    async fn unsummarized_count_decreases_after_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }
        assert_eq!(memory.unsummarized_message_count(cid).await.unwrap(), 10);

        memory.summarize(cid, 5).await.unwrap();

        assert!(memory.unsummarized_message_count(cid).await.unwrap() < 10);
        assert_eq!(memory.message_count(cid).await.unwrap(), 10);
    }

    #[tokio::test]
    async fn load_summaries_empty() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert!(summaries.is_empty());
    }

    // Summaries come back in insertion order with their stored content.
    #[tokio::test]
    async fn load_summaries_ordered() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let msg_id1 = memory.remember(cid, "user", "m1").await.unwrap();
        let msg_id2 = memory.remember(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = memory.remember(cid, "user", "m3").await.unwrap();

        let s1 = memory
            .sqlite()
            .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
            .await
            .unwrap();
        let s2 = memory
            .sqlite()
            .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
            .await
            .unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0].id, s1);
        assert_eq!(summaries[0].content, "summary1");
        assert_eq!(summaries[1].id, s2);
        assert_eq!(summaries[1].content, "summary2");
    }

    // No summary is produced at or below the message-count threshold.
    #[tokio::test]
    async fn summarize_below_threshold() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "hello").await.unwrap();

        let result = memory.summarize(cid, 10).await.unwrap();
        assert!(result.is_none());
    }
1225
1226 #[tokio::test]
1227 async fn summarize_stores_summary() {
1228 let memory = test_semantic_memory(false).await;
1229 let cid = memory.sqlite().create_conversation().await.unwrap();
1230
1231 for i in 0..5 {
1232 memory
1233 .remember(cid, "user", &format!("message {i}"))
1234 .await
1235 .unwrap();
1236 }
1237
1238 let summary_id = memory.summarize(cid, 3).await.unwrap();
1239 assert!(summary_id.is_some());
1240
1241 let summaries = memory.load_summaries(cid).await.unwrap();
1242 assert_eq!(summaries.len(), 1);
1243 assert_eq!(summaries[0].id, summary_id.unwrap());
1244 assert!(!summaries[0].content.is_empty());
1245 }
1246
1247 #[tokio::test]
1248 async fn summarize_respects_previous_summaries() {
1249 let memory = test_semantic_memory(false).await;
1250 let cid = memory.sqlite().create_conversation().await.unwrap();
1251
1252 for i in 0..10 {
1253 memory
1254 .remember(cid, "user", &format!("message {i}"))
1255 .await
1256 .unwrap();
1257 }
1258
1259 let s1 = memory.summarize(cid, 3).await.unwrap();
1260 assert!(s1.is_some());
1261
1262 let s2 = memory.summarize(cid, 3).await.unwrap();
1263 assert!(s2.is_some());
1264
1265 let summaries = memory.load_summaries(cid).await.unwrap();
1266 assert_eq!(summaries.len(), 2);
1267 assert!(summaries[0].last_message_id < summaries[1].first_message_id);
1268 }
1269
1270 #[tokio::test]
1271 async fn remember_multiple_messages_increments_ids() {
1272 let memory = test_semantic_memory(false).await;
1273 let cid = memory.sqlite.create_conversation().await.unwrap();
1274
1275 let id1 = memory.remember(cid, "user", "first").await.unwrap();
1276 let id2 = memory.remember(cid, "assistant", "second").await.unwrap();
1277 let id3 = memory.remember(cid, "user", "third").await.unwrap();
1278
1279 assert!(id1 < id2);
1280 assert!(id2 < id3);
1281 }
1282
    /// `message_count` is scoped per conversation: messages in one
    /// conversation must not inflate another's count.
    #[tokio::test]
    async fn message_count_across_conversations() {
        let memory = test_semantic_memory(false).await;
        let cid1 = memory.sqlite().create_conversation().await.unwrap();
        let cid2 = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid1, "user", "msg1").await.unwrap();
        memory.remember(cid1, "user", "msg2").await.unwrap();
        memory.remember(cid2, "user", "msg3").await.unwrap();

        assert_eq!(memory.message_count(cid1).await.unwrap(), 2);
        assert_eq!(memory.message_count(cid2).await.unwrap(), 1);
    }

    /// With exactly `threshold` messages, `summarize` is a no-op:
    /// summarization only triggers strictly above the threshold.
    #[tokio::test]
    async fn summarize_exact_threshold_returns_none() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..3 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await.unwrap();
        assert!(result.is_none());
    }

    /// One message above the threshold is enough to trigger summarization.
    #[tokio::test]
    async fn summarize_one_above_threshold_produces_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..4 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await.unwrap();
        assert!(result.is_some());
    }
1328
    /// Every field of a stored `Summary` must be populated with sane values
    /// after a successful `summarize` call.
    #[tokio::test]
    async fn summary_fields_populated() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        let s = &summaries[0];

        assert_eq!(s.conversation_id, cid);
        // Message IDs are 1-based, so the range must start above zero and
        // must not be inverted.
        assert!(s.first_message_id > MessageId(0));
        assert!(s.last_message_id >= s.first_message_id);
        assert!(s.token_estimate >= 0);
        assert!(!s.content.is_empty());
    }
1351
    /// The prompt must embed each message as "role: content" and mention the
    /// structured-output field name `key_facts`.
    #[test]
    fn build_summarization_prompt_format() {
        let messages = vec![
            (MessageId(1), "user".into(), "Hello".into()),
            (MessageId(2), "assistant".into(), "Hi there".into()),
        ];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("user: Hello"));
        assert!(prompt.contains("assistant: Hi there"));
        assert!(prompt.contains("key_facts"));
    }

    /// Even with no messages, the prompt skeleton (including the
    /// `key_facts` instruction) must be present.
    #[test]
    fn build_summarization_prompt_empty() {
        let messages: Vec<(MessageId, String, String)> = vec![];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("key_facts"));
    }

    /// `StructuredSummary` deserializes from the exact JSON shape the LLM is
    /// asked to produce.
    #[test]
    fn structured_summary_deserialize() {
        let json = r#"{"summary":"s","key_facts":["f1","f2"],"entities":["e1"]}"#;
        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
        assert_eq!(ss.summary, "s");
        assert_eq!(ss.key_facts.len(), 2);
        assert_eq!(ss.entities.len(), 1);
    }

    /// Empty `key_facts`/`entities` arrays are valid and deserialize to
    /// empty vectors rather than erroring.
    #[test]
    fn structured_summary_empty_facts() {
        let json = r#"{"summary":"s","key_facts":[],"entities":[]}"#;
        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
        assert!(ss.key_facts.is_empty());
        assert!(ss.entities.is_empty());
    }
1387
    /// Key-fact search degrades to an empty result set when no vector store
    /// is configured, instead of erroring.
    #[tokio::test]
    async fn search_key_facts_no_qdrant_empty() {
        let memory = test_semantic_memory(false).await;
        let facts = memory.search_key_facts("query", 5).await.unwrap();
        assert!(facts.is_empty());
    }

    /// `RecalledMessage` derives `Debug`; the output must include the type
    /// name and the score value.
    #[test]
    fn recalled_message_debug() {
        let recalled = RecalledMessage {
            message: Message {
                role: Role::User,
                content: "test".into(),
                parts: vec![],
                metadata: MessageMetadata::default(),
            },
            score: 0.95,
        };
        let dbg = format!("{recalled:?}");
        assert!(dbg.contains("RecalledMessage"));
        assert!(dbg.contains("0.95"));
    }

    /// `Summary` derives `Clone`; a clone must carry the same field values.
    #[test]
    fn summary_clone() {
        let summary = Summary {
            id: 1,
            conversation_id: ConversationId(2),
            content: "test summary".into(),
            first_message_id: MessageId(1),
            last_message_id: MessageId(5),
            token_estimate: 10,
        };
        let cloned = summary.clone();
        assert_eq!(summary.id, cloned.id);
        assert_eq!(summary.content, cloned.content);
    }
1425
1426 #[tokio::test]
1427 async fn remember_preserves_role_mapping() {
1428 let memory = test_semantic_memory(false).await;
1429 let cid = memory.sqlite.create_conversation().await.unwrap();
1430
1431 memory.remember(cid, "user", "u").await.unwrap();
1432 memory.remember(cid, "assistant", "a").await.unwrap();
1433 memory.remember(cid, "system", "s").await.unwrap();
1434
1435 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1436 assert_eq!(history.len(), 3);
1437 assert_eq!(history[0].role, Role::User);
1438 assert_eq!(history[1].role, Role::Assistant);
1439 assert_eq!(history[2].role, Role::System);
1440 }
1441
    /// Construction must not fail just because the vector-store URL is
    /// unreachable — the memory should degrade gracefully.
    /// NOTE(review): port 1 is assumed to be closed locally — confirm this
    /// holds in CI environments.
    #[tokio::test]
    async fn new_with_invalid_qdrant_url_graceful() {
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        let provider = AnyProvider::Mock(mock);
        let result =
            SemanticMemory::new(":memory:", "http://127.0.0.1:1", provider, "test-model").await;
        assert!(result.is_ok());
    }
1451
    /// End-to-end roundtrip with a SQLite-backed embedding store: remember
    /// two messages, recall by keyword, and read the history back in order.
    #[tokio::test]
    async fn test_semantic_memory_sqlite_remember_recall_roundtrip() {
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        let provider = AnyProvider::Mock(mock);

        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        // The embedding store shares the same SQLite connection pool.
        let qdrant = Some(crate::embedding_store::EmbeddingStore::new_sqlite(pool));

        let memory = SemanticMemory {
            sqlite,
            qdrant,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        };

        let cid = memory.sqlite().create_conversation().await.unwrap();

        let id1 = memory
            .remember(cid, "user", "rust async programming")
            .await
            .unwrap();
        let id2 = memory
            .remember(cid, "assistant", "use tokio for async")
            .await
            .unwrap();
        assert!(id1 < id2);

        let recalled = memory.recall("rust", 5, None).await.unwrap();
        assert!(
            !recalled.is_empty(),
            "recall must return at least one result"
        );

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "rust async programming");
    }
1504
1505 #[tokio::test]
1506 async fn remember_with_embeddings_supported_but_no_qdrant() {
1507 let memory = test_semantic_memory(true).await;
1508 let cid = memory.sqlite.create_conversation().await.unwrap();
1509
1510 let msg_id = memory.remember(cid, "user", "hello embed").await.unwrap();
1511 assert!(msg_id > MessageId(0));
1512
1513 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1514 assert_eq!(history.len(), 1);
1515 assert_eq!(history[0].content, "hello embed");
1516 }
1517
1518 #[tokio::test]
1519 async fn remember_verifies_content_via_load_history() {
1520 let memory = test_semantic_memory(false).await;
1521 let cid = memory.sqlite.create_conversation().await.unwrap();
1522
1523 memory.remember(cid, "user", "alpha").await.unwrap();
1524 memory.remember(cid, "assistant", "beta").await.unwrap();
1525 memory.remember(cid, "user", "gamma").await.unwrap();
1526
1527 let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1528 assert_eq!(history.len(), 3);
1529 assert_eq!(history[0].content, "alpha");
1530 assert_eq!(history[1].content, "beta");
1531 assert_eq!(history[2].content, "gamma");
1532 }
1533
    /// Counts stay isolated across three conversations, including one that
    /// never receives a message (count 0).
    #[tokio::test]
    async fn message_count_multiple_conversations_isolated() {
        let memory = test_semantic_memory(false).await;
        let cid1 = memory.sqlite().create_conversation().await.unwrap();
        let cid2 = memory.sqlite().create_conversation().await.unwrap();
        let cid3 = memory.sqlite().create_conversation().await.unwrap();

        for _ in 0..5 {
            memory.remember(cid1, "user", "msg").await.unwrap();
        }
        for _ in 0..3 {
            memory.remember(cid2, "user", "msg").await.unwrap();
        }

        assert_eq!(memory.message_count(cid1).await.unwrap(), 5);
        assert_eq!(memory.message_count(cid2).await.unwrap(), 3);
        assert_eq!(memory.message_count(cid3).await.unwrap(), 0);
    }

    /// NOTE(review): the name suggests an empty message range should yield
    /// `None`, but no `None` is asserted here — the test only checks that
    /// two consecutive `summarize` calls over 6 messages produce exactly two
    /// summaries. TODO confirm intent; consider renaming or asserting the
    /// `None` case explicitly.
    #[tokio::test]
    async fn summarize_empty_messages_range_returns_none() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..6 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        memory.summarize(cid, 3).await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
    }
1571
    /// A non-empty summary must carry a strictly positive token estimate.
    #[tokio::test]
    async fn summarize_token_estimate_populated() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        let token_est = summaries[0].token_estimate;
        assert!(token_est > 0);
    }

    /// When the LLM provider is unreachable, `summarize` must surface the
    /// error rather than silently storing nothing. Uses a real Ollama
    /// provider pointed at a closed local port to force the chat call to
    /// fail.
    #[tokio::test]
    async fn summarize_fails_when_provider_chat_fails() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let provider = AnyProvider::Ollama(zeph_llm::ollama::OllamaProvider::new(
            "http://127.0.0.1:1",
            "test".into(),
            "embed".into(),
        ));
        let memory = SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        };
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await;
        assert!(result.is_err());
    }
1623
    /// `embed_missing` must report zero embedded messages when the provider
    /// does not support embeddings, even if unembedded messages exist.
    #[tokio::test]
    async fn embed_missing_without_embedding_support_returns_zero() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();
        // Insert directly via the store so no embedding is created as a
        // side effect of `remember`.
        memory
            .sqlite()
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }
1637
1638 #[tokio::test]
1639 async fn has_embedding_returns_false_when_no_qdrant() {
1640 let memory = test_semantic_memory(false).await;
1641 let cid = memory.sqlite.create_conversation().await.unwrap();
1642 let msg_id = memory.remember(cid, "user", "test").await.unwrap();
1643 assert!(!memory.has_embedding(msg_id).await.unwrap());
1644 }
1645
    /// Vector recall with a filter still returns empty when no vector store
    /// is attached — the filter must not change the degradation behavior.
    #[tokio::test]
    async fn recall_empty_without_qdrant_regardless_of_filter() {
        let memory = test_semantic_memory(true).await;
        let filter = SearchFilter {
            conversation_id: Some(ConversationId(1)),
            role: None,
        };
        let recalled = memory.recall("query", 10, Some(filter)).await.unwrap();
        assert!(recalled.is_empty());
    }
1656
    /// The stored summary's message-ID range must be well-formed: start at
    /// least at the first message and never be inverted.
    #[tokio::test]
    async fn summarize_message_range_bounds() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..8 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let summary_id = memory.summarize(cid, 4).await.unwrap().unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].id, summary_id);
        assert!(summaries[0].first_message_id >= MessageId(1));
        assert!(summaries[0].last_message_id >= summaries[0].first_message_id);
    }

    /// Messages must appear in the prompt in their original order.
    #[test]
    fn build_summarization_prompt_preserves_order() {
        let messages = vec![
            (MessageId(1), "user".into(), "first".into()),
            (MessageId(2), "assistant".into(), "second".into()),
            (MessageId(3), "user".into(), "third".into()),
        ];
        let prompt = build_summarization_prompt(&messages);
        let first_pos = prompt.find("user: first").unwrap();
        let second_pos = prompt.find("assistant: second").unwrap();
        let third_pos = prompt.find("user: third").unwrap();
        assert!(first_pos < second_pos);
        assert!(second_pos < third_pos);
    }

    /// `Summary` derives `Debug`; the output must include the type name.
    #[test]
    fn summary_debug() {
        let summary = Summary {
            id: 1,
            conversation_id: ConversationId(2),
            content: "test".into(),
            first_message_id: MessageId(1),
            last_message_id: MessageId(5),
            token_estimate: 10,
        };
        let dbg = format!("{summary:?}");
        assert!(dbg.contains("Summary"));
    }
1705
    /// Counting messages for an unknown conversation yields 0, not an error.
    #[tokio::test]
    async fn message_count_nonexistent_conversation() {
        let memory = test_semantic_memory(false).await;
        let count = memory.message_count(ConversationId(999)).await.unwrap();
        assert_eq!(count, 0);
    }

    /// Loading summaries for an unknown conversation yields an empty list.
    #[tokio::test]
    async fn load_summaries_nonexistent_conversation() {
        let memory = test_semantic_memory(false).await;
        let summaries = memory.load_summaries(ConversationId(999)).await.unwrap();
        assert!(summaries.is_empty());
    }

    /// Storing a session summary is a silent no-op when no vector store is
    /// configured — it must not error.
    #[tokio::test]
    async fn store_session_summary_no_qdrant_noop() {
        let memory = test_semantic_memory(true).await;
        let result = memory
            .store_session_summary(ConversationId(1), "test summary")
            .await;
        assert!(result.is_ok());
    }

    /// Same no-op behavior when the provider lacks embedding support.
    #[tokio::test]
    async fn store_session_summary_no_embeddings_noop() {
        let memory = test_semantic_memory(false).await;
        let result = memory
            .store_session_summary(ConversationId(1), "test summary")
            .await;
        assert!(result.is_ok());
    }

    /// Session-summary search degrades to empty without a vector store.
    #[tokio::test]
    async fn search_session_summaries_no_qdrant_empty() {
        let memory = test_semantic_memory(true).await;
        let results = memory
            .search_session_summaries("query", 5, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    /// Session-summary search also degrades to empty without embedding
    /// support, even when a conversation filter is supplied.
    #[tokio::test]
    async fn search_session_summaries_no_embeddings_empty() {
        let memory = test_semantic_memory(false).await;
        let results = memory
            .search_session_summaries("query", 5, Some(ConversationId(1)))
            .await
            .unwrap();
        assert!(results.is_empty());
    }
1757
    /// `SessionSummaryResult` derives `Debug`; the output must include the
    /// type name.
    #[test]
    fn session_summary_result_debug() {
        let result = SessionSummaryResult {
            summary_text: "test".into(),
            score: 0.9,
            conversation_id: ConversationId(1),
        };
        let dbg = format!("{result:?}");
        assert!(dbg.contains("SessionSummaryResult"));
    }

    /// `SessionSummaryResult` derives `Clone`; a clone must carry the same
    /// field values.
    #[test]
    fn session_summary_result_clone() {
        let result = SessionSummaryResult {
            summary_text: "test".into(),
            score: 0.9,
            conversation_id: ConversationId(1),
        };
        let cloned = result.clone();
        assert_eq!(result.summary_text, cloned.summary_text);
        assert_eq!(result.conversation_id, cloned.conversation_id);
    }
1780
1781 #[tokio::test]
1782 async fn recall_fts5_fallback_without_qdrant() {
1783 let memory = test_semantic_memory(false).await;
1784 let cid = memory.sqlite.create_conversation().await.unwrap();
1785
1786 memory
1787 .remember(cid, "user", "rust programming guide")
1788 .await
1789 .unwrap();
1790 memory
1791 .remember(cid, "assistant", "python tutorial")
1792 .await
1793 .unwrap();
1794 memory
1795 .remember(cid, "user", "advanced rust patterns")
1796 .await
1797 .unwrap();
1798
1799 let recalled = memory.recall("rust", 5, None).await.unwrap();
1800 assert_eq!(recalled.len(), 2);
1801 assert!(recalled[0].score >= recalled[1].score);
1802 }
1803
1804 #[tokio::test]
1805 async fn recall_fts5_fallback_with_filter() {
1806 let memory = test_semantic_memory(false).await;
1807 let cid1 = memory.sqlite.create_conversation().await.unwrap();
1808 let cid2 = memory.sqlite.create_conversation().await.unwrap();
1809
1810 memory.remember(cid1, "user", "hello world").await.unwrap();
1811 memory
1812 .remember(cid2, "user", "hello universe")
1813 .await
1814 .unwrap();
1815
1816 let filter = SearchFilter {
1817 conversation_id: Some(cid1),
1818 role: None,
1819 };
1820 let recalled = memory.recall("hello", 5, Some(filter)).await.unwrap();
1821 assert_eq!(recalled.len(), 1);
1822 }
1823
1824 #[tokio::test]
1825 async fn recall_fts5_no_matches_returns_empty() {
1826 let memory = test_semantic_memory(false).await;
1827 let cid = memory.sqlite.create_conversation().await.unwrap();
1828
1829 memory.remember(cid, "user", "hello world").await.unwrap();
1830
1831 let recalled = memory.recall("nonexistent", 5, None).await.unwrap();
1832 assert!(recalled.is_empty());
1833 }
1834
1835 #[tokio::test]
1836 async fn recall_fts5_respects_limit() {
1837 let memory = test_semantic_memory(false).await;
1838 let cid = memory.sqlite.create_conversation().await.unwrap();
1839
1840 for i in 0..10 {
1841 memory
1842 .remember(cid, "user", &format!("test message number {i}"))
1843 .await
1844 .unwrap();
1845 }
1846
1847 let recalled = memory.recall("test", 3, None).await.unwrap();
1848 assert_eq!(recalled.len(), 3);
1849 }
1850
    /// When the provider's reply is not valid structured JSON, `summarize`
    /// must fall back to storing the plain-text response as the summary.
    #[tokio::test]
    async fn summarize_fallback_to_plain_text_when_structured_fails() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let mut mock = MockProvider::default();
        // Plain prose, deliberately not parseable as a StructuredSummary.
        mock.default_response = "plain text summary".into();
        let provider = AnyProvider::Mock(mock);

        let memory = SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        };

        let cid = memory.sqlite().create_conversation().await.unwrap();
        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await;
        assert!(result.is_ok());
        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert!(!summaries[0].content.is_empty());
    }
1901
    /// NOTE(review): despite the name, decay is NOT disabled here — a
    /// half-life of 30 is passed. Scores stay unchanged presumably because
    /// the timestamp map has no entries for the ranked ids; confirm against
    /// `apply_temporal_decay` and consider renaming (the half-life==0
    /// "disabled" path is covered by `temporal_decay_half_life_zero_is_noop`).
    #[test]
    fn temporal_decay_disabled_leaves_scores_unchanged() {
        let mut ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
        let timestamps = std::collections::HashMap::new();
        apply_temporal_decay(&mut ranked, &timestamps, 30);
        assert!((ranked[0].1 - 1.0).abs() < f64::EPSILON);
        assert!((ranked[1].1 - 0.5).abs() < f64::EPSILON);
    }

    /// A message timestamped "now" has zero age, so its score must survive
    /// decay essentially unchanged (small tolerance for clock skew).
    #[test]
    fn temporal_decay_zero_age_preserves_score() {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
            .cast_signed();
        let mut ranked = vec![(MessageId(1), 1.0f64)];
        let mut timestamps = std::collections::HashMap::new();
        timestamps.insert(MessageId(1), now);
        apply_temporal_decay(&mut ranked, &timestamps, 30);
        assert!((ranked[0].1 - 1.0).abs() < 0.01);
    }
1927
    /// A message exactly one half-life old must have its score halved
    /// (within tolerance) — the defining property of exponential decay.
    #[test]
    fn temporal_decay_half_life_halves_score() {
        let half_life = 30u32;
        // Age the timestamp by exactly one half-life, in seconds.
        let age_secs = i64::from(half_life) * 86400;
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
            .cast_signed();
        let ts = now - age_secs;
        let mut ranked = vec![(MessageId(1), 1.0f64)];
        let mut timestamps = std::collections::HashMap::new();
        timestamps.insert(MessageId(1), ts);
        apply_temporal_decay(&mut ranked, &timestamps, half_life);
        assert!(
            (ranked[0].1 - 0.5).abs() < 0.01,
            "score was {}",
            ranked[0].1
        );
    }
1950
    /// MMR on an empty candidate list returns an empty selection.
    #[test]
    fn mmr_empty_input_returns_empty() {
        let ranked = vec![];
        let vectors = std::collections::HashMap::new();
        let result = apply_mmr(&ranked, &vectors, 0.7, 5);
        assert!(result.is_empty());
    }

    /// With more candidates than the limit, MMR returns exactly `limit`
    /// results.
    #[test]
    fn mmr_returns_up_to_limit() {
        let ranked = vec![
            (MessageId(1), 1.0f64),
            (MessageId(2), 0.9f64),
            (MessageId(3), 0.8f64),
        ];
        let mut vectors = std::collections::HashMap::new();
        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
        vectors.insert(MessageId(3), vec![1.0f32, 0.0]);
        let result = apply_mmr(&ranked, &vectors, 0.7, 2);
        assert_eq!(result.len(), 2);
    }

    /// When no embedding vectors are available, MMR degrades to picking by
    /// relevance score alone.
    #[test]
    fn mmr_without_vectors_picks_by_relevance() {
        let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
        let vectors = std::collections::HashMap::new();
        let result = apply_mmr(&ranked, &vectors, 0.7, 2);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].0, MessageId(1));
    }
1984
1985 #[test]
1986 fn mmr_prefers_diverse_over_redundant() {
1987 let ranked = vec![
1989 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.9f64), ];
1993 let mut vectors = std::collections::HashMap::new();
1994 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
1995 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 2);
1998 assert_eq!(result.len(), 2);
1999 assert_eq!(result[0].0, MessageId(1));
2000 assert_eq!(result[1].0, MessageId(2));
2002 }
2003
    /// A half-life of 0 disables decay entirely: even a 30-day-old message
    /// keeps its exact score (matches the early-return in
    /// `apply_temporal_decay`).
    #[test]
    fn temporal_decay_half_life_zero_is_noop() {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
            .cast_signed();
        let age_secs = 30i64 * 86400;
        let ts = now - age_secs;
        let mut ranked = vec![(MessageId(1), 1.0f64)];
        let mut timestamps = std::collections::HashMap::new();
        timestamps.insert(MessageId(1), ts);
        apply_temporal_decay(&mut ranked, &timestamps, 0);
        assert!(
            (ranked[0].1 - 1.0).abs() < f64::EPSILON,
            "score was {}",
            ranked[0].1
        );
    }

    /// A ~10-year-old message with a 30-day half-life decays to practically
    /// zero (over 120 half-lives).
    #[test]
    fn temporal_decay_huge_age_near_zero() {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
            .cast_signed();
        let age_secs = 3650i64 * 86400;
        let ts = now - age_secs;
        let mut ranked = vec![(MessageId(1), 1.0f64)];
        let mut timestamps = std::collections::HashMap::new();
        timestamps.insert(MessageId(1), ts);
        apply_temporal_decay(&mut ranked, &timestamps, 30);
        assert!(ranked[0].1 < 0.001, "score was {}", ranked[0].1);
    }

    /// A 7-day-old message with a 1-day half-life loses almost all of its
    /// score (7 half-lives → below 1%).
    #[test]
    fn temporal_decay_small_half_life() {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
            .cast_signed();
        let ts = now - 7 * 86400i64;
        let mut ranked = vec![(MessageId(1), 1.0f64)];
        let mut timestamps = std::collections::HashMap::new();
        timestamps.insert(MessageId(1), ts);
        apply_temporal_decay(&mut ranked, &timestamps, 1);
        assert!(ranked[0].1 < 0.01, "score was {}", ranked[0].1);
    }
2058
2059 #[test]
2060 fn mmr_lambda_zero_max_diversity() {
2061 let ranked = vec![
2063 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.85f64), ];
2067 let mut vectors = std::collections::HashMap::new();
2068 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2069 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.0, 3);
2072 assert_eq!(result.len(), 3);
2073 assert_eq!(result[1].0, MessageId(2));
2075 }
2076
    /// With lambda 1 (pure relevance), MMR degenerates to the original
    /// score ordering regardless of vector similarity.
    #[test]
    fn mmr_lambda_one_pure_relevance() {
        let ranked = vec![
            (MessageId(1), 1.0f64),
            (MessageId(2), 0.8f64),
            (MessageId(3), 0.6f64),
        ];
        let mut vectors = std::collections::HashMap::new();
        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
        vectors.insert(MessageId(3), vec![0.5f32, 0.5]);
        let result = apply_mmr(&ranked, &vectors, 1.0, 3);
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].0, MessageId(1));
        assert_eq!(result[1].0, MessageId(2));
        assert_eq!(result[2].0, MessageId(3));
    }

    /// A limit of 0 selects nothing.
    #[test]
    fn mmr_limit_zero_returns_empty() {
        let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.8f64)];
        let mut vectors = std::collections::HashMap::new();
        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
        let result = apply_mmr(&ranked, &vectors, 0.7, 0);
        assert!(result.is_empty());
    }
2105
2106 #[test]
2107 fn mmr_duplicate_vectors_penalizes_second() {
2108 let ranked = vec![
2110 (MessageId(1), 1.0f64),
2111 (MessageId(2), 1.0f64), (MessageId(3), 0.9f64), ];
2114 let mut vectors = std::collections::HashMap::new();
2115 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2116 vectors.insert(MessageId(2), vec![1.0f32, 0.0]); vectors.insert(MessageId(3), vec![0.0f32, 1.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 3);
2119 assert_eq!(result.len(), 3);
2120 assert_eq!(result[0].0, MessageId(1));
2121 assert_eq!(result[1].0, MessageId(3));
2123 }
2124
    use proptest::prelude::*;

    proptest! {
        // Property: token counting must never panic on arbitrary input,
        // including empty strings and non-ASCII sequences; the count itself
        // is deliberately ignored.
        #[test]
        fn count_tokens_never_panics(s in ".*") {
            let counter = crate::token_counter::TokenCounter::new();
            let _ = counter.count_tokens(&s);
        }
    }
2136}