1use zeph_llm::any::AnyProvider;
2use zeph_llm::provider::{LlmProvider, Message, Role};
3
4use crate::embedding_store::{EmbeddingStore, MessageKind, SearchFilter};
5use crate::error::MemoryError;
6use crate::sqlite::SqliteStore;
7use crate::types::{ConversationId, MessageId};
8use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
9
/// Vector-store collection holding whole-session summary embeddings.
const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
/// Vector-store collection holding individual key-fact embeddings.
const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
12
/// Shape of the JSON object the LLM is asked to produce when summarizing
/// a conversation window (see `build_summarization_prompt`).
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    // Free-form prose summary of the conversation window.
    pub summary: String,
    // Discrete facts worth persisting independently of the summary text.
    pub key_facts: Vec<String>,
    // Named entities mentioned in the conversation.
    pub entities: Vec<String>,
}
19
/// A message returned from hybrid recall, paired with its combined
/// relevance score (higher is more relevant).
#[derive(Debug)]
pub struct RecalledMessage {
    pub message: Message,
    // Blended vector + keyword score; not normalized to a fixed range.
    pub score: f32,
}
25
/// A persisted conversation summary row, as loaded from SQLite.
#[derive(Debug, Clone)]
pub struct Summary {
    pub id: i64,
    pub conversation_id: ConversationId,
    // Summary text produced by the LLM.
    pub content: String,
    // First message covered by this summary (inclusive).
    pub first_message_id: MessageId,
    // Last message covered by this summary (inclusive).
    pub last_message_id: MessageId,
    // Approximate token count of `content` (see `estimate_tokens`).
    pub token_estimate: i64,
}
35
/// A session summary matched by semantic search over past conversations.
#[derive(Debug, Clone)]
pub struct SessionSummaryResult {
    pub summary_text: String,
    // Vector-store similarity score for the query.
    pub score: f32,
    // Conversation the stored summary belongs to.
    pub conversation_id: ConversationId,
}
42
/// Rough token-count heuristic: one token per four Unicode scalar values.
///
/// Intended for budgeting, not exact accounting. Integer division floors,
/// so strings shorter than four characters estimate to zero tokens.
#[must_use]
pub fn estimate_tokens(text: &str) -> usize {
    const CHARS_PER_TOKEN: usize = 4;
    text.chars().count() / CHARS_PER_TOKEN
}
51
/// Cosine similarity between two vectors.
///
/// Returns 0.0 for empty input, mismatched lengths, or when either vector
/// has zero magnitude (avoiding a division by zero).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Single pass: accumulate dot product and both squared norms together.
    let mut dot = 0.0_f32;
    let mut sq_a = 0.0_f32;
    let mut sq_b = 0.0_f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
64
65fn apply_temporal_decay(
66 ranked: &mut [(MessageId, f64)],
67 timestamps: &std::collections::HashMap<MessageId, i64>,
68 half_life_days: u32,
69) {
70 if half_life_days == 0 {
71 return;
72 }
73 let now = std::time::SystemTime::now()
74 .duration_since(std::time::UNIX_EPOCH)
75 .unwrap_or_default()
76 .as_secs()
77 .cast_signed();
78 let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
79
80 for (msg_id, score) in ranked.iter_mut() {
81 if let Some(&ts) = timestamps.get(msg_id) {
82 #[allow(clippy::cast_precision_loss)]
83 let age_days = (now - ts).max(0) as f64 / 86400.0;
84 *score *= (-lambda * age_days).exp();
85 }
86 }
87}
88
/// Re-ranks candidates with Maximal Marginal Relevance (MMR).
///
/// Greedily selects up to `limit` items, each time picking the candidate
/// maximizing `lambda * relevance - (1 - lambda) * max_sim`, where
/// `max_sim` is the highest cosine similarity to anything already selected.
/// Assumes `ranked` is sorted by descending relevance: the first pick is
/// always `ranked[0]`. Candidates without a vector get `max_sim = 0.0`
/// (treated as maximally diverse).
fn apply_mmr(
    ranked: &[(MessageId, f64)],
    vectors: &std::collections::HashMap<MessageId, Vec<f32>>,
    lambda: f32,
    limit: usize,
) -> Vec<(MessageId, f64)> {
    if ranked.is_empty() || limit == 0 {
        return Vec::new();
    }

    let lambda = f64::from(lambda);
    let mut selected: Vec<(MessageId, f64)> = Vec::with_capacity(limit);
    let mut remaining: Vec<(MessageId, f64)> = ranked.to_vec();

    while selected.len() < limit && !remaining.is_empty() {
        let best_idx = if selected.is_empty() {
            // First pick: highest-relevance candidate (input is pre-sorted).
            0
        } else {
            let mut best = 0usize;
            let mut best_score = f64::NEG_INFINITY;

            for (i, &(cand_id, relevance)) in remaining.iter().enumerate() {
                // Max similarity to the already-selected set.
                let max_sim = if let Some(cand_vec) = vectors.get(&cand_id) {
                    selected
                        .iter()
                        .filter_map(|(sel_id, _)| vectors.get(sel_id))
                        .map(|sel_vec| f64::from(cosine_similarity(cand_vec, sel_vec)))
                        .fold(f64::NEG_INFINITY, f64::max)
                } else {
                    0.0
                };
                // NEG_INFINITY means no selected item had a vector either;
                // treat that as "no similarity penalty".
                let max_sim = if max_sim == f64::NEG_INFINITY {
                    0.0
                } else {
                    max_sim
                };
                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim;
                // Strict `>` keeps the earlier (more relevant) item on ties.
                if mmr_score > best_score {
                    best_score = mmr_score;
                    best = i;
                }
            }
            best
        };

        selected.push(remaining.remove(best_idx));
    }

    selected
}
140
141fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
142 let mut prompt = String::from(
143 "Summarize the following conversation. Extract key facts, decisions, entities, \
144 and context needed to continue the conversation.\n\n\
145 Respond in JSON with fields: summary (string), key_facts (list of strings), \
146 entities (list of strings).\n\nConversation:\n",
147 );
148
149 for (_, role, content) in messages {
150 prompt.push_str(role);
151 prompt.push_str(": ");
152 prompt.push_str(content);
153 prompt.push('\n');
154 }
155
156 prompt
157}
158
/// Hybrid conversational memory: durable message storage in SQLite plus
/// optional embedding-based semantic search through a vector store.
pub struct SemanticMemory {
    sqlite: SqliteStore,
    // `None` when the vector store was unreachable at construction time;
    // all semantic features degrade to no-ops in that case.
    qdrant: Option<EmbeddingStore>,
    provider: AnyProvider,
    // Model name recorded alongside each stored embedding.
    embedding_model: String,
    // Weight applied to normalized vector-search scores during recall.
    vector_weight: f64,
    // Weight applied to normalized FTS5 keyword scores during recall.
    keyword_weight: f64,
    temporal_decay_enabled: bool,
    temporal_decay_half_life_days: u32,
    mmr_enabled: bool,
    mmr_lambda: f32,
}
171
impl SemanticMemory {
    /// Creates a memory instance with the default hybrid-ranking weights
    /// (0.7 vector, 0.3 keyword).
    ///
    /// # Errors
    ///
    /// Fails if the SQLite store cannot be opened. An unreachable vector
    /// store is not an error: semantic search is simply disabled.
    pub async fn new(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
    ) -> Result<Self, MemoryError> {
        Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
    }

    /// Creates a memory instance with explicit vector/keyword ranking weights.
    ///
    /// # Errors
    ///
    /// Fails if the SQLite store cannot be opened; a vector-store connection
    /// failure only logs a warning and disables semantic search.
    pub async fn with_weights(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::new(sqlite_path).await?;
        let pool = sqlite.pool().clone();

        // Vector store is best-effort: on failure, recall degrades to
        // keyword-only search.
        let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
            Ok(store) => Some(store),
            Err(e) => {
                tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
                None
            }
        };

        Ok(Self {
            sqlite,
            qdrant,
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
        })
    }

    /// Overrides temporal-decay and MMR re-ranking settings (builder style).
    #[must_use]
    pub fn with_ranking_options(
        mut self,
        temporal_decay_enabled: bool,
        temporal_decay_half_life_days: u32,
        mmr_enabled: bool,
        mmr_lambda: f32,
    ) -> Self {
        self.temporal_decay_enabled = temporal_decay_enabled;
        self.temporal_decay_half_life_days = temporal_decay_half_life_days;
        self.mmr_enabled = mmr_enabled;
        self.mmr_lambda = mmr_lambda;
        self
    }

    /// Creates a memory instance whose embeddings live in SQLite itself
    /// rather than an external vector-store server.
    ///
    /// # Errors
    ///
    /// Fails if the SQLite store cannot be opened.
    pub async fn with_sqlite_backend(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::new(sqlite_path).await?;
        let pool = sqlite.pool().clone();
        let store = EmbeddingStore::new_sqlite(pool);

        Ok(Self {
            sqlite,
            qdrant: Some(store),
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
        })
    }

    /// Persists a message to SQLite and, when a vector store and an
    /// embedding-capable provider are available, stores its embedding too.
    /// Embedding failures are logged and never fail the save.
    ///
    /// # Errors
    ///
    /// Fails only if the SQLite insert fails.
    pub async fn remember(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
    ) -> Result<MessageId, MemoryError> {
        let message_id = self
            .sqlite
            .save_message(conversation_id, role, content)
            .await?;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    // 896 is a fallback dimension — presumably the default
                    // for the configured embedding model; TODO confirm.
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok(message_id)
    }

    /// Like [`Self::remember`], but also persists structured message parts
    /// (`parts_json`) and reports whether an embedding was actually stored.
    ///
    /// Returns `(message_id, embedding_stored)`.
    ///
    /// # Errors
    ///
    /// Fails only if the SQLite insert fails.
    pub async fn remember_with_parts(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<(MessageId, bool), MemoryError> {
        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        let mut embedding_stored = false;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    // Same 896 fallback dimension as in `remember`.
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    } else {
                        embedding_stored = true;
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok((message_id, embedding_stored))
    }

    /// Persists a message (with parts) to SQLite only — no embedding is
    /// generated or stored.
    ///
    /// # Errors
    ///
    /// Fails if the SQLite insert fails.
    pub async fn save_only(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<MessageId, MemoryError> {
        self.sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await
    }

    /// Hybrid recall: combines FTS5 keyword search with vector similarity.
    ///
    /// Both result sets are fetched with `limit * 2` candidates, each score
    /// set is max-normalized, and the two are blended with
    /// `vector_weight`/`keyword_weight`. Optionally applies temporal decay
    /// and MMR diversification before truncating to `limit`.
    /// A keyword-search failure degrades to vector-only results.
    ///
    /// # Errors
    ///
    /// Fails if embedding the query or querying the vector store fails, or
    /// if loading the matched messages from SQLite fails.
    #[allow(clippy::too_many_lines)]
    pub async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let keyword_results = match self
            .sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
        {
            Ok(results) => results,
            Err(e) => {
                tracing::warn!("FTS5 keyword search failed: {e:#}");
                Vec::new()
            }
        };

        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            let query_vector = self.provider.embed(query).await?;
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant.search(&query_vector, limit * 2, filter).await?
        } else {
            Vec::new()
        };

        // Blend: message id -> weighted sum of normalized scores.
        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            // Normalize by the best vector score (guard against <= 0 maxima).
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            // Same max-normalization for the keyword side.
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            return Ok(Vec::new());
        }

        // Sort by blended score, best first.
        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    // Decay changes relative order; re-sort.
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                }
                Err(e) => {
                    // Best-effort: keep the undecayed ranking.
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // MMR needs the stored vectors; any failure falls back to a plain
        // truncation of the relevance ranking.
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                    }
                    Ok(_) => {
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Hydrate message bodies from SQLite; ids missing there are dropped.
        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        Ok(recalled)
    }

    /// Reports whether an embedding exists for `message_id`; always `false`
    /// when no vector store is configured.
    ///
    /// # Errors
    ///
    /// Fails if the vector-store lookup fails.
    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
        match &self.qdrant {
            Some(qdrant) => qdrant.has_embedding(message_id).await,
            None => Ok(false),
        }
    }

    /// Backfills embeddings for up to 1000 messages that have none yet,
    /// returning the number successfully embedded. Per-message failures are
    /// logged and skipped. No-op (returns 0) without a vector store or an
    /// embedding-capable provider.
    ///
    /// # Errors
    ///
    /// Fails if listing unembedded messages, probing the embedding size, or
    /// ensuring the collection fails.
    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(0);
        };
        if !self.provider.supports_embeddings() {
            return Ok(0);
        }

        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;

        if unembedded.is_empty() {
            return Ok(0);
        }

        // Probe once to learn the model's vector dimension before creating
        // the collection.
        let probe = self.provider.embed("probe").await?;
        let vector_size = u64::try_from(probe.len())?;
        qdrant.ensure_collection(vector_size).await?;

        let mut count = 0;
        for (msg_id, conversation_id, role, content) in &unembedded {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    if let Err(e) = qdrant
                        .store(
                            *msg_id,
                            *conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
                        continue;
                    }
                    count += 1;
                }
                Err(e) => {
                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
                }
            }
        }

        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
        Ok(count)
    }

    /// Embeds `summary_text` and stores it in the cross-session summaries
    /// collection. No-op without a vector store or embedding support.
    ///
    /// # Errors
    ///
    /// Fails if embedding or the vector-store write fails.
    pub async fn store_session_summary(
        &self,
        conversation_id: ConversationId,
        summary_text: &str,
    ) -> Result<(), MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(());
        };
        if !self.provider.supports_embeddings() {
            return Ok(());
        }

        let vector = self.provider.embed(summary_text).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        let payload = serde_json::json!({
            "conversation_id": conversation_id.0,
            "summary_text": summary_text,
        });

        qdrant
            .store_to_collection(SESSION_SUMMARIES_COLLECTION, payload, vector)
            .await?;

        tracing::debug!(
            conversation_id = conversation_id.0,
            "stored session summary"
        );
        Ok(())
    }

    /// Semantic search over stored session summaries, optionally excluding
    /// one conversation (typically the current one). Returns an empty list
    /// without a vector store or embedding support. Points with malformed
    /// payloads are silently skipped.
    ///
    /// # Errors
    ///
    /// Fails if embedding the query or the vector-store search fails.
    pub async fn search_session_summaries(
        &self,
        query: &str,
        limit: usize,
        exclude_conversation_id: Option<ConversationId>,
    ) -> Result<Vec<SessionSummaryResult>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        // must_not filter drops summaries of the excluded conversation.
        let filter = exclude_conversation_id.map(|cid| VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "conversation_id".into(),
                value: FieldValue::Integer(cid.0),
            }],
        });

        let points = qdrant
            .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter)
            .await?;

        let results = points
            .into_iter()
            .filter_map(|point| {
                let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned();
                let conversation_id =
                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
                Some(SessionSummaryResult {
                    summary_text,
                    score: point.score,
                    conversation_id,
                })
            })
            .collect();

        Ok(results)
    }

    /// Direct access to the underlying SQLite store.
    #[must_use]
    pub fn sqlite(&self) -> &SqliteStore {
        &self.sqlite
    }

    /// Health-checks the vector store; `false` when none is configured.
    pub async fn is_vector_store_connected(&self) -> bool {
        match self.qdrant.as_ref() {
            Some(store) => store.health_check().await,
            None => false,
        }
    }

    /// Whether a vector store was configured (regardless of reachability).
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }

    /// Total number of messages stored for `conversation_id`.
    ///
    /// # Errors
    ///
    /// Fails if the SQLite query fails.
    pub async fn message_count(&self, conversation_id: ConversationId) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }

    /// Number of messages not yet covered by any summary, i.e. messages
    /// after the latest summary's `last_message_id` (or all of them when no
    /// summary exists).
    ///
    /// # Errors
    ///
    /// Fails if the SQLite queries fail.
    pub async fn unsummarized_message_count(
        &self,
        conversation_id: ConversationId,
    ) -> Result<i64, MemoryError> {
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));
        self.sqlite
            .count_messages_after(conversation_id, after_id)
            .await
    }

    /// Loads all stored summaries for a conversation as [`Summary`] values.
    ///
    /// # Errors
    ///
    /// Fails if the SQLite query fails.
    pub async fn load_summaries(
        &self,
        conversation_id: ConversationId,
    ) -> Result<Vec<Summary>, MemoryError> {
        let rows = self.sqlite.load_summaries(conversation_id).await?;
        let summaries = rows
            .into_iter()
            .map(
                |(
                    id,
                    conversation_id,
                    content,
                    first_message_id,
                    last_message_id,
                    token_estimate,
                )| {
                    Summary {
                        id,
                        conversation_id,
                        content,
                        first_message_id,
                        last_message_id,
                        token_estimate,
                    }
                },
            )
            .collect();
        Ok(summaries)
    }

    /// Summarizes the next `message_count` messages after the latest summary,
    /// persists the result, and returns its row id. Returns `Ok(None)` when
    /// the conversation has at most `message_count` messages, or when there
    /// is nothing left to summarize.
    ///
    /// Uses structured LLM output ([`StructuredSummary`]); on failure, falls
    /// back to a plain-text completion with empty facts/entities. The summary
    /// is also embedded (best-effort) and any key facts are stored.
    ///
    /// # Errors
    ///
    /// Fails if the SQLite reads/writes fail or if both the structured and
    /// fallback LLM calls fail.
    pub async fn summarize(
        &self,
        conversation_id: ConversationId,
        message_count: usize,
    ) -> Result<Option<i64>, MemoryError> {
        let total = self.sqlite.count_messages(conversation_id).await?;

        // Don't summarize until there is more than one batch worth of
        // messages.
        if total <= i64::try_from(message_count)? {
            return Ok(None);
        }

        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));

        let messages = self
            .sqlite
            .load_messages_range(conversation_id, after_id, message_count)
            .await?;

        if messages.is_empty() {
            return Ok(None);
        }

        let prompt = build_summarization_prompt(&messages);
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
        }];

        let structured = match self
            .provider
            .chat_typed_erased::<StructuredSummary>(&chat_messages)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!(
                    "structured summarization failed, falling back to plain text: {e:#}"
                );
                let plain = self.provider.chat(&chat_messages).await?;
                StructuredSummary {
                    summary: plain,
                    key_facts: vec![],
                    entities: vec![],
                }
            }
        };
        let summary_text = &structured.summary;

        let token_estimate = i64::try_from(estimate_tokens(summary_text))?;
        let first_message_id = messages[0].0;
        let last_message_id = messages[messages.len() - 1].0;

        let summary_id = self
            .sqlite
            .save_summary(
                conversation_id,
                summary_text,
                first_message_id,
                last_message_id,
                token_estimate,
            )
            .await?;

        // Best-effort: embed the summary so it participates in recall.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(summary_text).await {
                Ok(vector) => {
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            // Summary embeddings reuse the summary row id as
                            // a MessageId, tagged with MessageKind::Summary.
                            MessageId(summary_id),
                            conversation_id,
                            "system",
                            vector,
                            MessageKind::Summary,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to embed summary: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate summary embedding: {e:#}");
                }
            }
        }

        if !structured.key_facts.is_empty() {
            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
                .await;
        }

        Ok(Some(summary_id))
    }

    /// Best-effort storage of extracted key facts in the key-facts
    /// collection; all failures are logged, never propagated.
    async fn store_key_facts(
        &self,
        conversation_id: ConversationId,
        source_summary_id: i64,
        key_facts: &[String],
    ) {
        let Some(qdrant) = &self.qdrant else {
            return;
        };
        if !self.provider.supports_embeddings() {
            return;
        }

        let Some(first_fact) = key_facts.first() else {
            return;
        };
        // Embed the first fact up front: its length determines the vector
        // size needed to ensure the collection exists.
        let first_vector = match self.provider.embed(first_fact).await {
            Ok(v) => v,
            Err(e) => {
                tracing::warn!("Failed to embed key fact: {e:#}");
                return;
            }
        };
        let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
        if let Err(e) = qdrant
            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
            .await
        {
            tracing::warn!("Failed to ensure key_facts collection: {e:#}");
            return;
        }

        let first_payload = serde_json::json!({
            "conversation_id": conversation_id.0,
            "fact_text": first_fact,
            "source_summary_id": source_summary_id,
        });
        if let Err(e) = qdrant
            .store_to_collection(KEY_FACTS_COLLECTION, first_payload, first_vector)
            .await
        {
            tracing::warn!("Failed to store key fact: {e:#}");
        }

        // Remaining facts: embed and store each independently.
        for fact in &key_facts[1..] {
            match self.provider.embed(fact).await {
                Ok(vector) => {
                    let payload = serde_json::json!({
                        "conversation_id": conversation_id.0,
                        "fact_text": fact,
                        "source_summary_id": source_summary_id,
                    });
                    if let Err(e) = qdrant
                        .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
                        .await
                    {
                        tracing::warn!("Failed to store key fact: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to embed key fact: {e:#}");
                }
            }
        }
    }

    /// Semantic search over stored key facts; returns their texts. Empty
    /// without a vector store or embedding support.
    ///
    /// # Errors
    ///
    /// Fails if embedding the query or the vector-store search fails.
    pub async fn search_key_facts(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<String>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
            .await?;

        let points = qdrant
            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
            .await?;

        let facts = points
            .into_iter()
            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
            .collect();

        Ok(facts)
    }
}
997
998#[cfg(test)]
999mod tests {
1000 use zeph_llm::mock::MockProvider;
1001 use zeph_llm::provider::Role;
1002
1003 use super::*;
1004
1005 fn test_provider() -> AnyProvider {
1006 AnyProvider::Mock(MockProvider::default())
1007 }
1008
1009 async fn test_semantic_memory(_supports_embeddings: bool) -> SemanticMemory {
1010 let provider = test_provider();
1011 let sqlite = SqliteStore::new(":memory:").await.unwrap();
1012
1013 SemanticMemory {
1014 sqlite,
1015 qdrant: None,
1016 provider,
1017 embedding_model: "test-model".into(),
1018 vector_weight: 0.7,
1019 keyword_weight: 0.3,
1020 temporal_decay_enabled: false,
1021 temporal_decay_half_life_days: 30,
1022 mmr_enabled: false,
1023 mmr_lambda: 0.7,
1024 }
1025 }
1026
1027 #[tokio::test]
1028 async fn remember_saves_to_sqlite() {
1029 let memory = test_semantic_memory(false).await;
1030
1031 let cid = memory.sqlite.create_conversation().await.unwrap();
1032 let msg_id = memory.remember(cid, "user", "hello").await.unwrap();
1033
1034 assert_eq!(msg_id, MessageId(1));
1035
1036 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1037 assert_eq!(history.len(), 1);
1038 assert_eq!(history[0].role, Role::User);
1039 assert_eq!(history[0].content, "hello");
1040 }
1041
1042 #[tokio::test]
1043 async fn remember_with_parts_saves_parts_json() {
1044 let memory = test_semantic_memory(false).await;
1045 let cid = memory.sqlite.create_conversation().await.unwrap();
1046
1047 let parts_json =
1048 r#"[{"kind":"ToolOutput","tool_name":"shell","body":"hello","compacted_at":null}]"#;
1049 let (msg_id, _embedding_stored) = memory
1050 .remember_with_parts(cid, "assistant", "tool output", parts_json)
1051 .await
1052 .unwrap();
1053 assert!(msg_id > MessageId(0));
1054
1055 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1056 assert_eq!(history.len(), 1);
1057 assert_eq!(history[0].content, "tool output");
1058 }
1059
1060 #[tokio::test]
1061 async fn recall_returns_empty_without_qdrant() {
1062 let memory = test_semantic_memory(true).await;
1063
1064 let recalled = memory.recall("test", 5, None).await.unwrap();
1065 assert!(recalled.is_empty());
1066 }
1067
1068 #[tokio::test]
1069 async fn has_embedding_without_qdrant() {
1070 let memory = test_semantic_memory(true).await;
1071
1072 let has_embedding = memory.has_embedding(MessageId(1)).await.unwrap();
1073 assert!(!has_embedding);
1074 }
1075
1076 #[tokio::test]
1077 async fn embed_missing_without_qdrant() {
1078 let memory = test_semantic_memory(true).await;
1079
1080 let count = memory.embed_missing().await.unwrap();
1081 assert_eq!(count, 0);
1082 }
1083
1084 #[tokio::test]
1085 async fn sqlite_accessor() {
1086 let memory = test_semantic_memory(false).await;
1087
1088 let cid = memory.sqlite().create_conversation().await.unwrap();
1089 assert_eq!(cid, ConversationId(1));
1090
1091 memory
1092 .sqlite()
1093 .save_message(cid, "user", "test")
1094 .await
1095 .unwrap();
1096
1097 let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1098 assert_eq!(history.len(), 1);
1099 }
1100
1101 #[tokio::test]
1102 async fn has_vector_store_returns_false_when_unavailable() {
1103 let memory = test_semantic_memory(false).await;
1104 assert!(!memory.has_vector_store());
1105 }
1106
1107 #[tokio::test]
1108 async fn is_vector_store_connected_returns_false_when_unavailable() {
1109 let memory = test_semantic_memory(false).await;
1110 assert!(!memory.is_vector_store_connected().await);
1111 }
1112
1113 #[tokio::test]
1114 async fn recall_returns_empty_when_embeddings_not_supported() {
1115 let memory = test_semantic_memory(false).await;
1116
1117 let recalled = memory.recall("test", 5, None).await.unwrap();
1118 assert!(recalled.is_empty());
1119 }
1120
1121 #[tokio::test]
1122 async fn embed_missing_returns_zero_when_embeddings_not_supported() {
1123 let memory = test_semantic_memory(false).await;
1124
1125 let cid = memory.sqlite().create_conversation().await.unwrap();
1126 memory
1127 .sqlite()
1128 .save_message(cid, "user", "test")
1129 .await
1130 .unwrap();
1131
1132 let count = memory.embed_missing().await.unwrap();
1133 assert_eq!(count, 0);
1134 }
1135
1136 #[test]
1137 fn estimate_tokens_ascii() {
1138 let text = "Hello, world!";
1140 assert_eq!(estimate_tokens(text), 3);
1141 }
1142
1143 #[test]
1144 fn estimate_tokens_unicode() {
1145 let text = "Привет мир";
1147 assert_eq!(estimate_tokens(text), 2);
1148 }
1149
1150 #[test]
1151 fn estimate_tokens_empty() {
1152 assert_eq!(estimate_tokens(""), 0);
1153 }
1154
1155 #[test]
1156 fn estimate_tokens_cjk() {
1157 let text = "你好世界テスト日";
1159 assert_eq!(estimate_tokens(text), 2);
1160 }
1161
1162 #[tokio::test]
1163 async fn message_count_empty_conversation() {
1164 let memory = test_semantic_memory(false).await;
1165 let cid = memory.sqlite().create_conversation().await.unwrap();
1166
1167 let count = memory.message_count(cid).await.unwrap();
1168 assert_eq!(count, 0);
1169 }
1170
1171 #[tokio::test]
1172 async fn message_count_after_saves() {
1173 let memory = test_semantic_memory(false).await;
1174 let cid = memory.sqlite().create_conversation().await.unwrap();
1175
1176 memory.remember(cid, "user", "msg1").await.unwrap();
1177 memory.remember(cid, "assistant", "msg2").await.unwrap();
1178
1179 let count = memory.message_count(cid).await.unwrap();
1180 assert_eq!(count, 2);
1181 }
1182
1183 #[tokio::test]
1184 async fn unsummarized_count_decreases_after_summary() {
1185 let memory = test_semantic_memory(false).await;
1186 let cid = memory.sqlite().create_conversation().await.unwrap();
1187
1188 for i in 0..10 {
1189 memory
1190 .remember(cid, "user", &format!("msg{i}"))
1191 .await
1192 .unwrap();
1193 }
1194 assert_eq!(memory.unsummarized_message_count(cid).await.unwrap(), 10);
1195
1196 memory.summarize(cid, 5).await.unwrap();
1197
1198 assert!(memory.unsummarized_message_count(cid).await.unwrap() < 10);
1199 assert_eq!(memory.message_count(cid).await.unwrap(), 10);
1200 }
1201
1202 #[tokio::test]
1203 async fn load_summaries_empty() {
1204 let memory = test_semantic_memory(false).await;
1205 let cid = memory.sqlite().create_conversation().await.unwrap();
1206
1207 let summaries = memory.load_summaries(cid).await.unwrap();
1208 assert!(summaries.is_empty());
1209 }
1210
1211 #[tokio::test]
1212 async fn load_summaries_ordered() {
1213 let memory = test_semantic_memory(false).await;
1214 let cid = memory.sqlite().create_conversation().await.unwrap();
1215
1216 let msg_id1 = memory.remember(cid, "user", "m1").await.unwrap();
1217 let msg_id2 = memory.remember(cid, "assistant", "m2").await.unwrap();
1218 let msg_id3 = memory.remember(cid, "user", "m3").await.unwrap();
1219
1220 let s1 = memory
1221 .sqlite()
1222 .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
1223 .await
1224 .unwrap();
1225 let s2 = memory
1226 .sqlite()
1227 .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
1228 .await
1229 .unwrap();
1230
1231 let summaries = memory.load_summaries(cid).await.unwrap();
1232 assert_eq!(summaries.len(), 2);
1233 assert_eq!(summaries[0].id, s1);
1234 assert_eq!(summaries[0].content, "summary1");
1235 assert_eq!(summaries[1].id, s2);
1236 assert_eq!(summaries[1].content, "summary2");
1237 }
1238
1239 #[tokio::test]
1240 async fn summarize_below_threshold() {
1241 let memory = test_semantic_memory(false).await;
1242 let cid = memory.sqlite().create_conversation().await.unwrap();
1243
1244 memory.remember(cid, "user", "hello").await.unwrap();
1245
1246 let result = memory.summarize(cid, 10).await.unwrap();
1247 assert!(result.is_none());
1248 }
1249
1250 #[tokio::test]
1251 async fn summarize_stores_summary() {
1252 let memory = test_semantic_memory(false).await;
1253 let cid = memory.sqlite().create_conversation().await.unwrap();
1254
1255 for i in 0..5 {
1256 memory
1257 .remember(cid, "user", &format!("message {i}"))
1258 .await
1259 .unwrap();
1260 }
1261
1262 let summary_id = memory.summarize(cid, 3).await.unwrap();
1263 assert!(summary_id.is_some());
1264
1265 let summaries = memory.load_summaries(cid).await.unwrap();
1266 assert_eq!(summaries.len(), 1);
1267 assert_eq!(summaries[0].id, summary_id.unwrap());
1268 assert!(!summaries[0].content.is_empty());
1269 }
1270
1271 #[tokio::test]
1272 async fn summarize_respects_previous_summaries() {
1273 let memory = test_semantic_memory(false).await;
1274 let cid = memory.sqlite().create_conversation().await.unwrap();
1275
1276 for i in 0..10 {
1277 memory
1278 .remember(cid, "user", &format!("message {i}"))
1279 .await
1280 .unwrap();
1281 }
1282
1283 let s1 = memory.summarize(cid, 3).await.unwrap();
1284 assert!(s1.is_some());
1285
1286 let s2 = memory.summarize(cid, 3).await.unwrap();
1287 assert!(s2.is_some());
1288
1289 let summaries = memory.load_summaries(cid).await.unwrap();
1290 assert_eq!(summaries.len(), 2);
1291 assert!(summaries[0].last_message_id < summaries[1].first_message_id);
1292 }
1293
1294 #[tokio::test]
1295 async fn remember_multiple_messages_increments_ids() {
1296 let memory = test_semantic_memory(false).await;
1297 let cid = memory.sqlite.create_conversation().await.unwrap();
1298
1299 let id1 = memory.remember(cid, "user", "first").await.unwrap();
1300 let id2 = memory.remember(cid, "assistant", "second").await.unwrap();
1301 let id3 = memory.remember(cid, "user", "third").await.unwrap();
1302
1303 assert!(id1 < id2);
1304 assert!(id2 < id3);
1305 }
1306
1307 #[tokio::test]
1308 async fn message_count_across_conversations() {
1309 let memory = test_semantic_memory(false).await;
1310 let cid1 = memory.sqlite().create_conversation().await.unwrap();
1311 let cid2 = memory.sqlite().create_conversation().await.unwrap();
1312
1313 memory.remember(cid1, "user", "msg1").await.unwrap();
1314 memory.remember(cid1, "user", "msg2").await.unwrap();
1315 memory.remember(cid2, "user", "msg3").await.unwrap();
1316
1317 assert_eq!(memory.message_count(cid1).await.unwrap(), 2);
1318 assert_eq!(memory.message_count(cid2).await.unwrap(), 1);
1319 }
1320
1321 #[tokio::test]
1322 async fn summarize_exact_threshold_returns_none() {
1323 let memory = test_semantic_memory(false).await;
1324 let cid = memory.sqlite().create_conversation().await.unwrap();
1325
1326 for i in 0..3 {
1327 memory
1328 .remember(cid, "user", &format!("msg {i}"))
1329 .await
1330 .unwrap();
1331 }
1332
1333 let result = memory.summarize(cid, 3).await.unwrap();
1334 assert!(result.is_none());
1335 }
1336
1337 #[tokio::test]
1338 async fn summarize_one_above_threshold_produces_summary() {
1339 let memory = test_semantic_memory(false).await;
1340 let cid = memory.sqlite().create_conversation().await.unwrap();
1341
1342 for i in 0..4 {
1343 memory
1344 .remember(cid, "user", &format!("msg {i}"))
1345 .await
1346 .unwrap();
1347 }
1348
1349 let result = memory.summarize(cid, 3).await.unwrap();
1350 assert!(result.is_some());
1351 }
1352
1353 #[tokio::test]
1354 async fn summary_fields_populated() {
1355 let memory = test_semantic_memory(false).await;
1356 let cid = memory.sqlite().create_conversation().await.unwrap();
1357
1358 for i in 0..5 {
1359 memory
1360 .remember(cid, "user", &format!("msg {i}"))
1361 .await
1362 .unwrap();
1363 }
1364
1365 memory.summarize(cid, 3).await.unwrap();
1366 let summaries = memory.load_summaries(cid).await.unwrap();
1367 let s = &summaries[0];
1368
1369 assert_eq!(s.conversation_id, cid);
1370 assert!(s.first_message_id > MessageId(0));
1371 assert!(s.last_message_id >= s.first_message_id);
1372 assert!(s.token_estimate >= 0);
1373 assert!(!s.content.is_empty());
1374 }
1375
1376 #[test]
1377 fn build_summarization_prompt_format() {
1378 let messages = vec![
1379 (MessageId(1), "user".into(), "Hello".into()),
1380 (MessageId(2), "assistant".into(), "Hi there".into()),
1381 ];
1382 let prompt = build_summarization_prompt(&messages);
1383 assert!(prompt.contains("user: Hello"));
1384 assert!(prompt.contains("assistant: Hi there"));
1385 assert!(prompt.contains("key_facts"));
1386 }
1387
1388 #[test]
1389 fn build_summarization_prompt_empty() {
1390 let messages: Vec<(MessageId, String, String)> = vec![];
1391 let prompt = build_summarization_prompt(&messages);
1392 assert!(prompt.contains("key_facts"));
1393 }
1394
1395 #[test]
1396 fn structured_summary_deserialize() {
1397 let json = r#"{"summary":"s","key_facts":["f1","f2"],"entities":["e1"]}"#;
1398 let ss: StructuredSummary = serde_json::from_str(json).unwrap();
1399 assert_eq!(ss.summary, "s");
1400 assert_eq!(ss.key_facts.len(), 2);
1401 assert_eq!(ss.entities.len(), 1);
1402 }
1403
1404 #[test]
1405 fn structured_summary_empty_facts() {
1406 let json = r#"{"summary":"s","key_facts":[],"entities":[]}"#;
1407 let ss: StructuredSummary = serde_json::from_str(json).unwrap();
1408 assert!(ss.key_facts.is_empty());
1409 assert!(ss.entities.is_empty());
1410 }
1411
1412 #[tokio::test]
1413 async fn search_key_facts_no_qdrant_empty() {
1414 let memory = test_semantic_memory(false).await;
1415 let facts = memory.search_key_facts("query", 5).await.unwrap();
1416 assert!(facts.is_empty());
1417 }
1418
1419 #[test]
1420 fn recalled_message_debug() {
1421 let recalled = RecalledMessage {
1422 message: Message {
1423 role: Role::User,
1424 content: "test".into(),
1425 parts: vec![],
1426 },
1427 score: 0.95,
1428 };
1429 let dbg = format!("{recalled:?}");
1430 assert!(dbg.contains("RecalledMessage"));
1431 assert!(dbg.contains("0.95"));
1432 }
1433
1434 #[test]
1435 fn summary_clone() {
1436 let summary = Summary {
1437 id: 1,
1438 conversation_id: ConversationId(2),
1439 content: "test summary".into(),
1440 first_message_id: MessageId(1),
1441 last_message_id: MessageId(5),
1442 token_estimate: 10,
1443 };
1444 let cloned = summary.clone();
1445 assert_eq!(summary.id, cloned.id);
1446 assert_eq!(summary.content, cloned.content);
1447 }
1448
1449 #[test]
1450 fn estimate_tokens_short_text() {
1451 assert_eq!(estimate_tokens("ab"), 0);
1453 }
1454
1455 #[test]
1456 fn estimate_tokens_longer_text() {
1457 let text = "a".repeat(100);
1459 assert_eq!(estimate_tokens(&text), 25);
1460 }
1461
1462 #[tokio::test]
1463 async fn remember_preserves_role_mapping() {
1464 let memory = test_semantic_memory(false).await;
1465 let cid = memory.sqlite.create_conversation().await.unwrap();
1466
1467 memory.remember(cid, "user", "u").await.unwrap();
1468 memory.remember(cid, "assistant", "a").await.unwrap();
1469 memory.remember(cid, "system", "s").await.unwrap();
1470
1471 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1472 assert_eq!(history.len(), 3);
1473 assert_eq!(history[0].role, Role::User);
1474 assert_eq!(history[1].role, Role::Assistant);
1475 assert_eq!(history[2].role, Role::System);
1476 }
1477
1478 #[tokio::test]
1479 async fn new_with_invalid_qdrant_url_graceful() {
1480 let mut mock = MockProvider::default();
1481 mock.supports_embeddings = true;
1482 let provider = AnyProvider::Mock(mock);
1483 let result =
1484 SemanticMemory::new(":memory:", "http://127.0.0.1:1", provider, "test-model").await;
1485 assert!(result.is_ok());
1486 }
1487
    // End-to-end check: a SemanticMemory wired to the sqlite-backed embedding
    // store must persist messages, serve recall for a keyword query, and
    // preserve insertion-ordered history.
    #[tokio::test]
    async fn test_semantic_memory_sqlite_remember_recall_roundtrip() {
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        let provider = AnyProvider::Mock(mock);

        // Share one pool between the message store and the embedding store so
        // both sides see the same in-memory database.
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let qdrant = Some(crate::embedding_store::EmbeddingStore::new_sqlite(pool));

        let memory = SemanticMemory {
            sqlite,
            qdrant,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
        };

        let cid = memory.sqlite().create_conversation().await.unwrap();

        let id1 = memory
            .remember(cid, "user", "rust async programming")
            .await
            .unwrap();
        let id2 = memory
            .remember(cid, "assistant", "use tokio for async")
            .await
            .unwrap();
        // Ids are assigned in insertion order.
        assert!(id1 < id2);

        let recalled = memory.recall("rust", 5, None).await.unwrap();
        assert!(
            !recalled.is_empty(),
            "recall must return at least one result"
        );

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "rust async programming");
    }
1539
1540 #[tokio::test]
1541 async fn remember_with_embeddings_supported_but_no_qdrant() {
1542 let memory = test_semantic_memory(true).await;
1543 let cid = memory.sqlite.create_conversation().await.unwrap();
1544
1545 let msg_id = memory.remember(cid, "user", "hello embed").await.unwrap();
1546 assert!(msg_id > MessageId(0));
1547
1548 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1549 assert_eq!(history.len(), 1);
1550 assert_eq!(history[0].content, "hello embed");
1551 }
1552
1553 #[tokio::test]
1554 async fn remember_verifies_content_via_load_history() {
1555 let memory = test_semantic_memory(false).await;
1556 let cid = memory.sqlite.create_conversation().await.unwrap();
1557
1558 memory.remember(cid, "user", "alpha").await.unwrap();
1559 memory.remember(cid, "assistant", "beta").await.unwrap();
1560 memory.remember(cid, "user", "gamma").await.unwrap();
1561
1562 let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1563 assert_eq!(history.len(), 3);
1564 assert_eq!(history[0].content, "alpha");
1565 assert_eq!(history[1].content, "beta");
1566 assert_eq!(history[2].content, "gamma");
1567 }
1568
1569 #[tokio::test]
1570 async fn message_count_multiple_conversations_isolated() {
1571 let memory = test_semantic_memory(false).await;
1572 let cid1 = memory.sqlite().create_conversation().await.unwrap();
1573 let cid2 = memory.sqlite().create_conversation().await.unwrap();
1574 let cid3 = memory.sqlite().create_conversation().await.unwrap();
1575
1576 for _ in 0..5 {
1577 memory.remember(cid1, "user", "msg").await.unwrap();
1578 }
1579 for _ in 0..3 {
1580 memory.remember(cid2, "user", "msg").await.unwrap();
1581 }
1582
1583 assert_eq!(memory.message_count(cid1).await.unwrap(), 5);
1584 assert_eq!(memory.message_count(cid2).await.unwrap(), 3);
1585 assert_eq!(memory.message_count(cid3).await.unwrap(), 0);
1586 }
1587
1588 #[tokio::test]
1589 async fn summarize_empty_messages_range_returns_none() {
1590 let memory = test_semantic_memory(false).await;
1591 let cid = memory.sqlite().create_conversation().await.unwrap();
1592
1593 for i in 0..6 {
1594 memory
1595 .remember(cid, "user", &format!("msg {i}"))
1596 .await
1597 .unwrap();
1598 }
1599
1600 memory.summarize(cid, 3).await.unwrap();
1601 memory.summarize(cid, 3).await.unwrap();
1602
1603 let summaries = memory.load_summaries(cid).await.unwrap();
1604 assert_eq!(summaries.len(), 2);
1605 }
1606
1607 #[tokio::test]
1608 async fn summarize_token_estimate_populated() {
1609 let memory = test_semantic_memory(false).await;
1610 let cid = memory.sqlite().create_conversation().await.unwrap();
1611
1612 for i in 0..5 {
1613 memory
1614 .remember(cid, "user", &format!("message {i}"))
1615 .await
1616 .unwrap();
1617 }
1618
1619 memory.summarize(cid, 3).await.unwrap();
1620 let summaries = memory.load_summaries(cid).await.unwrap();
1621 let token_est = summaries[0].token_estimate;
1622 let expected = i64::try_from(estimate_tokens(&summaries[0].content)).unwrap();
1623 assert_eq!(token_est, expected);
1624 }
1625
    // When the chat provider is unreachable, summarize must surface the error
    // rather than silently storing a bogus summary.
    #[tokio::test]
    async fn summarize_fails_when_provider_chat_fails() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        // Ollama pointed at a closed port: every chat call will fail.
        let provider = AnyProvider::Ollama(zeph_llm::ollama::OllamaProvider::new(
            "http://127.0.0.1:1",
            "test".into(),
            "embed".into(),
        ));
        let memory = SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
        };
        let cid = memory.sqlite().create_conversation().await.unwrap();

        // Enough messages to exceed the summarization threshold below.
        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await;
        assert!(result.is_err());
    }
1658
1659 #[tokio::test]
1660 async fn embed_missing_without_embedding_support_returns_zero() {
1661 let memory = test_semantic_memory(false).await;
1662 let cid = memory.sqlite().create_conversation().await.unwrap();
1663 memory
1664 .sqlite()
1665 .save_message(cid, "user", "test message")
1666 .await
1667 .unwrap();
1668
1669 let count = memory.embed_missing().await.unwrap();
1670 assert_eq!(count, 0);
1671 }
1672
1673 #[tokio::test]
1674 async fn has_embedding_returns_false_when_no_qdrant() {
1675 let memory = test_semantic_memory(false).await;
1676 let cid = memory.sqlite.create_conversation().await.unwrap();
1677 let msg_id = memory.remember(cid, "user", "test").await.unwrap();
1678 assert!(!memory.has_embedding(msg_id).await.unwrap());
1679 }
1680
1681 #[tokio::test]
1682 async fn recall_empty_without_qdrant_regardless_of_filter() {
1683 let memory = test_semantic_memory(true).await;
1684 let filter = SearchFilter {
1685 conversation_id: Some(ConversationId(1)),
1686 role: None,
1687 };
1688 let recalled = memory.recall("query", 10, Some(filter)).await.unwrap();
1689 assert!(recalled.is_empty());
1690 }
1691
1692 #[tokio::test]
1693 async fn summarize_message_range_bounds() {
1694 let memory = test_semantic_memory(false).await;
1695 let cid = memory.sqlite().create_conversation().await.unwrap();
1696
1697 for i in 0..8 {
1698 memory
1699 .remember(cid, "user", &format!("msg {i}"))
1700 .await
1701 .unwrap();
1702 }
1703
1704 let summary_id = memory.summarize(cid, 4).await.unwrap().unwrap();
1705 let summaries = memory.load_summaries(cid).await.unwrap();
1706 assert_eq!(summaries.len(), 1);
1707 assert_eq!(summaries[0].id, summary_id);
1708 assert!(summaries[0].first_message_id >= MessageId(1));
1709 assert!(summaries[0].last_message_id >= summaries[0].first_message_id);
1710 }
1711
1712 #[test]
1713 fn build_summarization_prompt_preserves_order() {
1714 let messages = vec![
1715 (MessageId(1), "user".into(), "first".into()),
1716 (MessageId(2), "assistant".into(), "second".into()),
1717 (MessageId(3), "user".into(), "third".into()),
1718 ];
1719 let prompt = build_summarization_prompt(&messages);
1720 let first_pos = prompt.find("user: first").unwrap();
1721 let second_pos = prompt.find("assistant: second").unwrap();
1722 let third_pos = prompt.find("user: third").unwrap();
1723 assert!(first_pos < second_pos);
1724 assert!(second_pos < third_pos);
1725 }
1726
1727 #[test]
1728 fn summary_debug() {
1729 let summary = Summary {
1730 id: 1,
1731 conversation_id: ConversationId(2),
1732 content: "test".into(),
1733 first_message_id: MessageId(1),
1734 last_message_id: MessageId(5),
1735 token_estimate: 10,
1736 };
1737 let dbg = format!("{summary:?}");
1738 assert!(dbg.contains("Summary"));
1739 }
1740
1741 #[tokio::test]
1742 async fn message_count_nonexistent_conversation() {
1743 let memory = test_semantic_memory(false).await;
1744 let count = memory.message_count(ConversationId(999)).await.unwrap();
1745 assert_eq!(count, 0);
1746 }
1747
1748 #[tokio::test]
1749 async fn load_summaries_nonexistent_conversation() {
1750 let memory = test_semantic_memory(false).await;
1751 let summaries = memory.load_summaries(ConversationId(999)).await.unwrap();
1752 assert!(summaries.is_empty());
1753 }
1754
1755 #[tokio::test]
1756 async fn store_session_summary_no_qdrant_noop() {
1757 let memory = test_semantic_memory(true).await;
1758 let result = memory
1759 .store_session_summary(ConversationId(1), "test summary")
1760 .await;
1761 assert!(result.is_ok());
1762 }
1763
1764 #[tokio::test]
1765 async fn store_session_summary_no_embeddings_noop() {
1766 let memory = test_semantic_memory(false).await;
1767 let result = memory
1768 .store_session_summary(ConversationId(1), "test summary")
1769 .await;
1770 assert!(result.is_ok());
1771 }
1772
1773 #[tokio::test]
1774 async fn search_session_summaries_no_qdrant_empty() {
1775 let memory = test_semantic_memory(true).await;
1776 let results = memory
1777 .search_session_summaries("query", 5, None)
1778 .await
1779 .unwrap();
1780 assert!(results.is_empty());
1781 }
1782
1783 #[tokio::test]
1784 async fn search_session_summaries_no_embeddings_empty() {
1785 let memory = test_semantic_memory(false).await;
1786 let results = memory
1787 .search_session_summaries("query", 5, Some(ConversationId(1)))
1788 .await
1789 .unwrap();
1790 assert!(results.is_empty());
1791 }
1792
1793 #[test]
1794 fn session_summary_result_debug() {
1795 let result = SessionSummaryResult {
1796 summary_text: "test".into(),
1797 score: 0.9,
1798 conversation_id: ConversationId(1),
1799 };
1800 let dbg = format!("{result:?}");
1801 assert!(dbg.contains("SessionSummaryResult"));
1802 }
1803
1804 #[test]
1805 fn session_summary_result_clone() {
1806 let result = SessionSummaryResult {
1807 summary_text: "test".into(),
1808 score: 0.9,
1809 conversation_id: ConversationId(1),
1810 };
1811 let cloned = result.clone();
1812 assert_eq!(result.summary_text, cloned.summary_text);
1813 assert_eq!(result.conversation_id, cloned.conversation_id);
1814 }
1815
1816 #[tokio::test]
1817 async fn recall_fts5_fallback_without_qdrant() {
1818 let memory = test_semantic_memory(false).await;
1819 let cid = memory.sqlite.create_conversation().await.unwrap();
1820
1821 memory
1822 .remember(cid, "user", "rust programming guide")
1823 .await
1824 .unwrap();
1825 memory
1826 .remember(cid, "assistant", "python tutorial")
1827 .await
1828 .unwrap();
1829 memory
1830 .remember(cid, "user", "advanced rust patterns")
1831 .await
1832 .unwrap();
1833
1834 let recalled = memory.recall("rust", 5, None).await.unwrap();
1835 assert_eq!(recalled.len(), 2);
1836 assert!(recalled[0].score >= recalled[1].score);
1837 }
1838
1839 #[tokio::test]
1840 async fn recall_fts5_fallback_with_filter() {
1841 let memory = test_semantic_memory(false).await;
1842 let cid1 = memory.sqlite.create_conversation().await.unwrap();
1843 let cid2 = memory.sqlite.create_conversation().await.unwrap();
1844
1845 memory.remember(cid1, "user", "hello world").await.unwrap();
1846 memory
1847 .remember(cid2, "user", "hello universe")
1848 .await
1849 .unwrap();
1850
1851 let filter = SearchFilter {
1852 conversation_id: Some(cid1),
1853 role: None,
1854 };
1855 let recalled = memory.recall("hello", 5, Some(filter)).await.unwrap();
1856 assert_eq!(recalled.len(), 1);
1857 }
1858
1859 #[tokio::test]
1860 async fn recall_fts5_no_matches_returns_empty() {
1861 let memory = test_semantic_memory(false).await;
1862 let cid = memory.sqlite.create_conversation().await.unwrap();
1863
1864 memory.remember(cid, "user", "hello world").await.unwrap();
1865
1866 let recalled = memory.recall("nonexistent", 5, None).await.unwrap();
1867 assert!(recalled.is_empty());
1868 }
1869
1870 #[tokio::test]
1871 async fn recall_fts5_respects_limit() {
1872 let memory = test_semantic_memory(false).await;
1873 let cid = memory.sqlite.create_conversation().await.unwrap();
1874
1875 for i in 0..10 {
1876 memory
1877 .remember(cid, "user", &format!("test message number {i}"))
1878 .await
1879 .unwrap();
1880 }
1881
1882 let recalled = memory.recall("test", 3, None).await.unwrap();
1883 assert_eq!(recalled.len(), 3);
1884 }
1885
    // If the provider does not produce parseable structured output, summarize
    // must fall back to storing the provider's plain-text response.
    #[tokio::test]
    async fn summarize_fallback_to_plain_text_when_structured_fails() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        // The mock returns a non-JSON string, forcing the fallback path.
        let mut mock = MockProvider::default();
        mock.default_response = "plain text summary".into();
        let provider = AnyProvider::Mock(mock);

        let memory = SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
        };

        let cid = memory.sqlite().create_conversation().await.unwrap();
        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await;
        assert!(result.is_ok());
        // The plain-text fallback is stored as the summary content.
        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert!(!summaries[0].content.is_empty());
    }
1935
1936 #[test]
1939 fn temporal_decay_disabled_leaves_scores_unchanged() {
1940 let mut ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
1941 let timestamps = std::collections::HashMap::new();
1942 apply_temporal_decay(&mut ranked, ×tamps, 30);
1943 assert!((ranked[0].1 - 1.0).abs() < f64::EPSILON);
1944 assert!((ranked[1].1 - 0.5).abs() < f64::EPSILON);
1945 }
1946
1947 #[test]
1948 fn temporal_decay_zero_age_preserves_score() {
1949 let now = std::time::SystemTime::now()
1950 .duration_since(std::time::UNIX_EPOCH)
1951 .unwrap_or_default()
1952 .as_secs()
1953 .cast_signed();
1954 let mut ranked = vec![(MessageId(1), 1.0f64)];
1955 let mut timestamps = std::collections::HashMap::new();
1956 timestamps.insert(MessageId(1), now);
1957 apply_temporal_decay(&mut ranked, ×tamps, 30);
1958 assert!((ranked[0].1 - 1.0).abs() < 0.01);
1960 }
1961
1962 #[test]
1963 fn temporal_decay_half_life_halves_score() {
1964 let half_life = 30u32;
1966 let age_secs = i64::from(half_life) * 86400;
1967 let now = std::time::SystemTime::now()
1968 .duration_since(std::time::UNIX_EPOCH)
1969 .unwrap_or_default()
1970 .as_secs()
1971 .cast_signed();
1972 let ts = now - age_secs;
1973 let mut ranked = vec![(MessageId(1), 1.0f64)];
1974 let mut timestamps = std::collections::HashMap::new();
1975 timestamps.insert(MessageId(1), ts);
1976 apply_temporal_decay(&mut ranked, ×tamps, half_life);
1977 assert!(
1979 (ranked[0].1 - 0.5).abs() < 0.01,
1980 "score was {}",
1981 ranked[0].1
1982 );
1983 }
1984
1985 #[test]
1988 fn mmr_empty_input_returns_empty() {
1989 let ranked = vec![];
1990 let vectors = std::collections::HashMap::new();
1991 let result = apply_mmr(&ranked, &vectors, 0.7, 5);
1992 assert!(result.is_empty());
1993 }
1994
1995 #[test]
1996 fn mmr_returns_up_to_limit() {
1997 let ranked = vec![
1998 (MessageId(1), 1.0f64),
1999 (MessageId(2), 0.9f64),
2000 (MessageId(3), 0.8f64),
2001 ];
2002 let mut vectors = std::collections::HashMap::new();
2003 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2004 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2005 vectors.insert(MessageId(3), vec![1.0f32, 0.0]);
2006 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2007 assert_eq!(result.len(), 2);
2008 }
2009
2010 #[test]
2011 fn mmr_without_vectors_picks_by_relevance() {
2012 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2013 let vectors = std::collections::HashMap::new();
2014 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2015 assert_eq!(result.len(), 2);
2016 assert_eq!(result[0].0, MessageId(1));
2017 }
2018
2019 #[test]
2020 fn mmr_prefers_diverse_over_redundant() {
2021 let ranked = vec![
2023 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.9f64), ];
2027 let mut vectors = std::collections::HashMap::new();
2028 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2029 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 2);
2032 assert_eq!(result.len(), 2);
2033 assert_eq!(result[0].0, MessageId(1));
2034 assert_eq!(result[1].0, MessageId(2));
2036 }
2037
2038 #[test]
2039 fn temporal_decay_half_life_zero_is_noop() {
2040 let now = std::time::SystemTime::now()
2041 .duration_since(std::time::UNIX_EPOCH)
2042 .unwrap_or_default()
2043 .as_secs()
2044 .cast_signed();
2045 let age_secs = 30i64 * 86400;
2046 let ts = now - age_secs;
2047 let mut ranked = vec![(MessageId(1), 1.0f64)];
2048 let mut timestamps = std::collections::HashMap::new();
2049 timestamps.insert(MessageId(1), ts);
2050 apply_temporal_decay(&mut ranked, ×tamps, 0);
2052 assert!(
2053 (ranked[0].1 - 1.0).abs() < f64::EPSILON,
2054 "score was {}",
2055 ranked[0].1
2056 );
2057 }
2058
2059 #[test]
2060 fn temporal_decay_huge_age_near_zero() {
2061 let now = std::time::SystemTime::now()
2062 .duration_since(std::time::UNIX_EPOCH)
2063 .unwrap_or_default()
2064 .as_secs()
2065 .cast_signed();
2066 let age_secs = 3650i64 * 86400;
2068 let ts = now - age_secs;
2069 let mut ranked = vec![(MessageId(1), 1.0f64)];
2070 let mut timestamps = std::collections::HashMap::new();
2071 timestamps.insert(MessageId(1), ts);
2072 apply_temporal_decay(&mut ranked, ×tamps, 30);
2073 assert!(ranked[0].1 < 0.001, "score was {}", ranked[0].1);
2075 }
2076
2077 #[test]
2078 fn temporal_decay_small_half_life() {
2079 let now = std::time::SystemTime::now()
2081 .duration_since(std::time::UNIX_EPOCH)
2082 .unwrap_or_default()
2083 .as_secs()
2084 .cast_signed();
2085 let ts = now - 7 * 86400i64;
2086 let mut ranked = vec![(MessageId(1), 1.0f64)];
2087 let mut timestamps = std::collections::HashMap::new();
2088 timestamps.insert(MessageId(1), ts);
2089 apply_temporal_decay(&mut ranked, ×tamps, 1);
2090 assert!(ranked[0].1 < 0.01, "score was {}", ranked[0].1);
2091 }
2092
2093 #[test]
2094 fn mmr_lambda_zero_max_diversity() {
2095 let ranked = vec![
2097 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.85f64), ];
2101 let mut vectors = std::collections::HashMap::new();
2102 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2103 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.0, 3);
2106 assert_eq!(result.len(), 3);
2107 assert_eq!(result[1].0, MessageId(2));
2109 }
2110
2111 #[test]
2112 fn mmr_lambda_one_pure_relevance() {
2113 let ranked = vec![
2115 (MessageId(1), 1.0f64),
2116 (MessageId(2), 0.8f64),
2117 (MessageId(3), 0.6f64),
2118 ];
2119 let mut vectors = std::collections::HashMap::new();
2120 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2121 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2122 vectors.insert(MessageId(3), vec![0.5f32, 0.5]);
2123 let result = apply_mmr(&ranked, &vectors, 1.0, 3);
2124 assert_eq!(result.len(), 3);
2125 assert_eq!(result[0].0, MessageId(1));
2126 assert_eq!(result[1].0, MessageId(2));
2127 assert_eq!(result[2].0, MessageId(3));
2128 }
2129
2130 #[test]
2131 fn mmr_limit_zero_returns_empty() {
2132 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.8f64)];
2133 let mut vectors = std::collections::HashMap::new();
2134 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2135 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2136 let result = apply_mmr(&ranked, &vectors, 0.7, 0);
2137 assert!(result.is_empty());
2138 }
2139
2140 #[test]
2141 fn mmr_duplicate_vectors_penalizes_second() {
2142 let ranked = vec![
2144 (MessageId(1), 1.0f64),
2145 (MessageId(2), 1.0f64), (MessageId(3), 0.9f64), ];
2148 let mut vectors = std::collections::HashMap::new();
2149 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2150 vectors.insert(MessageId(2), vec![1.0f32, 0.0]); vectors.insert(MessageId(3), vec![0.0f32, 1.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 3);
2153 assert_eq!(result.len(), 3);
2154 assert_eq!(result[0].0, MessageId(1));
2155 assert_eq!(result[1].0, MessageId(3));
2157 }
2158
2159 use proptest::prelude::*;
2162
    proptest! {
        // Property: `estimate_tokens` is total over arbitrary (unicode) input
        // and never panics.
        #[test]
        fn estimate_tokens_never_panics(s in ".*") {
            let _ = estimate_tokens(&s);
        }
    }
2169}