1use zeph_llm::any::AnyProvider;
5use zeph_llm::provider::{LlmProvider, Message, MessageMetadata, Role};
6
7use std::sync::Arc;
8
9use crate::embedding_store::{EmbeddingStore, MessageKind, SearchFilter};
10use crate::error::MemoryError;
11use crate::sqlite::SqliteStore;
12use crate::token_counter::TokenCounter;
13use crate::types::{ConversationId, MessageId};
14use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
15
/// Vector collection holding one embedding per stored session summary.
const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
/// Vector collection holding one embedding per extracted key fact.
const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
/// Vector collection holding one embedding per stored user correction.
const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
19
/// JSON shape the summarization prompt asks the LLM to produce
/// (see `build_summarization_prompt`).
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    // Free-form prose summary of the summarized message range.
    pub summary: String,
    // Discrete facts extracted from the conversation.
    pub key_facts: Vec<String>,
    // Entities mentioned in the conversation, as extracted by the model.
    pub entities: Vec<String>,
}
26
/// A message surfaced by hybrid recall, paired with its combined relevance score.
#[derive(Debug)]
pub struct RecalledMessage {
    // The recalled chat message loaded from SQLite.
    pub message: Message,
    // Combined vector+keyword score (after optional decay/MMR), truncated to f32.
    pub score: f32,
}
32
/// A persisted conversation summary row, covering a contiguous message range.
#[derive(Debug, Clone)]
pub struct Summary {
    // SQLite row id of the summary.
    pub id: i64,
    // Conversation this summary belongs to.
    pub conversation_id: ConversationId,
    // The summary text itself.
    pub content: String,
    // First message covered by this summary.
    pub first_message_id: MessageId,
    // Last message covered by this summary.
    pub last_message_id: MessageId,
    // Token count of `content` as estimated by `TokenCounter`.
    pub token_estimate: i64,
}
42
/// A session-summary hit returned from vector search over past sessions.
#[derive(Debug, Clone)]
pub struct SessionSummaryResult {
    // Summary text stored in the point payload.
    pub summary_text: String,
    // Vector similarity score reported by the store.
    pub score: f32,
    // Conversation the summary came from.
    pub conversation_id: ConversationId,
}
49
/// Cosine similarity of two vectors; 0.0 for mismatched lengths, empty input,
/// or a zero-magnitude operand.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Dimensions must agree and be non-empty for the similarity to be defined.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let mut dot = 0.0_f32;
    let mut sum_sq_a = 0.0_f32;
    let mut sum_sq_b = 0.0_f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sum_sq_a += x * x;
        sum_sq_b += y * y;
    }

    let norm_a = sum_sq_a.sqrt();
    let norm_b = sum_sq_b.sqrt();
    // Guard against division by zero on degenerate (all-zero) vectors.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
62
63fn apply_temporal_decay(
64 ranked: &mut [(MessageId, f64)],
65 timestamps: &std::collections::HashMap<MessageId, i64>,
66 half_life_days: u32,
67) {
68 if half_life_days == 0 {
69 return;
70 }
71 let now = std::time::SystemTime::now()
72 .duration_since(std::time::UNIX_EPOCH)
73 .unwrap_or_default()
74 .as_secs()
75 .cast_signed();
76 let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
77
78 for (msg_id, score) in ranked.iter_mut() {
79 if let Some(&ts) = timestamps.get(msg_id) {
80 #[allow(clippy::cast_precision_loss)]
81 let age_days = (now - ts).max(0) as f64 / 86400.0;
82 *score *= (-lambda * age_days).exp();
83 }
84 }
85}
86
/// Greedy Maximal-Marginal-Relevance re-ranking: repeatedly picks the item
/// that best trades relevance against similarity to the already-selected set.
///
/// `lambda` weights relevance (1.0 = pure relevance, 0.0 = pure diversity);
/// `limit` caps the result length. Candidates missing from `vectors` are
/// treated as maximally diverse (similarity 0).
fn apply_mmr(
    ranked: &[(MessageId, f64)],
    vectors: &std::collections::HashMap<MessageId, Vec<f32>>,
    lambda: f32,
    limit: usize,
) -> Vec<(MessageId, f64)> {
    if ranked.is_empty() || limit == 0 {
        return Vec::new();
    }

    let lambda = f64::from(lambda);
    let mut selected: Vec<(MessageId, f64)> = Vec::with_capacity(limit);
    let mut remaining: Vec<(MessageId, f64)> = ranked.to_vec();

    while selected.len() < limit && !remaining.is_empty() {
        let best_idx = if selected.is_empty() {
            // First pick: take the head (the caller passes `ranked` sorted by score).
            0
        } else {
            let mut best = 0usize;
            let mut best_score = f64::NEG_INFINITY;

            for (i, &(cand_id, relevance)) in remaining.iter().enumerate() {
                // Highest similarity between the candidate and any selected item.
                let max_sim = if let Some(cand_vec) = vectors.get(&cand_id) {
                    selected
                        .iter()
                        .filter_map(|(sel_id, _)| vectors.get(sel_id))
                        .map(|sel_vec| f64::from(cosine_similarity(cand_vec, sel_vec)))
                        .fold(f64::NEG_INFINITY, f64::max)
                } else {
                    0.0
                };
                // NEG_INFINITY means no selected item had a vector either.
                let max_sim = if max_sim == f64::NEG_INFINITY {
                    0.0
                } else {
                    max_sim
                };
                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim;
                if mmr_score > best_score {
                    best_score = mmr_score;
                    best = i;
                }
            }
            best
        };

        selected.push(remaining.remove(best_idx));
    }

    selected
}
138
139fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
140 let mut prompt = String::from(
141 "Summarize the following conversation. Extract key facts, decisions, entities, \
142 and context needed to continue the conversation.\n\n\
143 Respond in JSON with fields: summary (string), key_facts (list of strings), \
144 entities (list of strings).\n\nConversation:\n",
145 );
146
147 for (_, role, content) in messages {
148 prompt.push_str(role);
149 prompt.push_str(": ");
150 prompt.push_str(content);
151 prompt.push('\n');
152 }
153
154 prompt
155}
156
/// Hybrid conversation memory: SQLite for message storage and FTS5 keyword
/// search, plus an optional vector store for semantic recall.
pub struct SemanticMemory {
    // Durable message / summary storage and FTS5 keyword search.
    sqlite: SqliteStore,
    // Optional vector store; `None` means semantic search is disabled.
    qdrant: Option<EmbeddingStore>,
    // LLM provider used for embeddings and summarization.
    provider: AnyProvider,
    // Model name recorded alongside stored embeddings.
    embedding_model: String,
    // Weight applied to normalized vector-search scores during fusion.
    vector_weight: f64,
    // Weight applied to normalized keyword-search scores during fusion.
    keyword_weight: f64,
    // When true, older messages are down-weighted during ranking.
    temporal_decay_enabled: bool,
    // Half-life (days) for the exponential decay; 0 disables it.
    temporal_decay_half_life_days: u32,
    // When true, results are diversified with MMR.
    mmr_enabled: bool,
    // MMR relevance/diversity trade-off (1.0 = pure relevance).
    mmr_lambda: f32,
    // Shared token counter used for summary token estimates.
    pub token_counter: Arc<TokenCounter>,
}
170
171impl SemanticMemory {
    /// Creates a semantic memory with the default score weights
    /// (0.7 vector / 0.3 keyword).
    ///
    /// # Errors
    /// Fails if the SQLite store cannot be opened.
    pub async fn new(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
    ) -> Result<Self, MemoryError> {
        Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
    }
187
    /// Creates a semantic memory with explicit fusion weights and the default
    /// SQLite pool size (5 connections).
    ///
    /// # Errors
    /// Fails if the SQLite store cannot be opened.
    pub async fn with_weights(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_weights_and_pool_size(
            sqlite_path,
            qdrant_url,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            5,
        )
        .await
    }
212
    /// Full constructor: explicit fusion weights and SQLite pool size.
    ///
    /// If the Qdrant store cannot be created, semantic search is disabled
    /// (logged) rather than failing construction. Ranking extras (decay, MMR)
    /// default to off; tune them via `with_ranking_options`.
    ///
    /// # Errors
    /// Fails only if the SQLite store cannot be opened.
    pub async fn with_weights_and_pool_size(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();

        // Qdrant is optional: fall back to keyword-only search when unreachable.
        let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
            Ok(store) => Some(store),
            Err(e) => {
                tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
                None
            }
        };

        Ok(Self {
            sqlite,
            qdrant,
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        })
    }
252
253 #[must_use]
255 pub fn with_ranking_options(
256 mut self,
257 temporal_decay_enabled: bool,
258 temporal_decay_half_life_days: u32,
259 mmr_enabled: bool,
260 mmr_lambda: f32,
261 ) -> Self {
262 self.temporal_decay_enabled = temporal_decay_enabled;
263 self.temporal_decay_half_life_days = temporal_decay_half_life_days;
264 self.mmr_enabled = mmr_enabled;
265 self.mmr_lambda = mmr_lambda;
266 self
267 }
268
    /// Test/mock constructor assembling a memory from pre-built components.
    /// Ranking extras default to off.
    #[cfg(any(test, feature = "mock"))]
    #[must_use]
    pub fn from_parts(
        sqlite: SqliteStore,
        qdrant: Option<EmbeddingStore>,
        provider: AnyProvider,
        embedding_model: impl Into<String>,
        vector_weight: f64,
        keyword_weight: f64,
        token_counter: Arc<TokenCounter>,
    ) -> Self {
        Self {
            sqlite,
            qdrant,
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter,
        }
    }
297
    /// Creates a memory whose vector store is backed by SQLite itself (no
    /// external Qdrant), using the default pool size of 5.
    ///
    /// # Errors
    /// Fails if the SQLite store cannot be opened.
    pub async fn with_sqlite_backend(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_sqlite_backend_and_pool_size(
            sqlite_path,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            5,
        )
        .await
    }
320
    /// Like `with_sqlite_backend`, with an explicit SQLite pool size.
    /// Unlike the Qdrant constructors, the vector store here is always present.
    ///
    /// # Errors
    /// Fails if the SQLite store cannot be opened.
    pub async fn with_sqlite_backend_and_pool_size(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();
        let store = EmbeddingStore::new_sqlite(pool);

        Ok(Self {
            sqlite,
            qdrant: Some(store),
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        })
    }
352
353 pub async fn remember(
362 &self,
363 conversation_id: ConversationId,
364 role: &str,
365 content: &str,
366 ) -> Result<MessageId, MemoryError> {
367 let message_id = self
368 .sqlite
369 .save_message(conversation_id, role, content)
370 .await?;
371
372 if let Some(qdrant) = &self.qdrant
373 && self.provider.supports_embeddings()
374 {
375 match self.provider.embed(content).await {
376 Ok(vector) => {
377 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
379 if let Err(e) = qdrant.ensure_collection(vector_size).await {
380 tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
381 } else if let Err(e) = qdrant
382 .store(
383 message_id,
384 conversation_id,
385 role,
386 vector,
387 MessageKind::Regular,
388 &self.embedding_model,
389 )
390 .await
391 {
392 tracing::warn!("Failed to store embedding: {e:#}");
393 }
394 }
395 Err(e) => {
396 tracing::warn!("Failed to generate embedding: {e:#}");
397 }
398 }
399 }
400
401 Ok(message_id)
402 }
403
    /// Persists a message along with its structured parts JSON, then
    /// best-effort embeds it. Returns the new message id and whether the
    /// embedding was actually stored.
    ///
    /// # Errors
    /// Fails only if the SQLite write fails; embedding problems are logged.
    pub async fn remember_with_parts(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<(MessageId, bool), MemoryError> {
        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        let mut embedding_stored = false;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    // Collection size comes from the produced vector's dimensionality.
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    } else {
                        embedding_stored = true;
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok((message_id, embedding_stored))
    }
458
    /// Persists a message (with parts JSON) to SQLite without generating any
    /// embedding — unlike `remember_with_parts`.
    ///
    /// # Errors
    /// Propagates SQLite errors.
    pub async fn save_only(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<MessageId, MemoryError> {
        self.sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await
    }
477
    /// Hybrid recall: runs FTS5 keyword search (best-effort) and vector search
    /// (when available), over-fetching 2x `limit` from each, then fuses and
    /// ranks via `recall_merge_and_rank`.
    ///
    /// # Errors
    /// Fails if the embedding call or vector search fails; keyword-search
    /// failures are logged and treated as empty results.
    pub async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let keyword_results = match self
            .sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
        {
            Ok(results) => results,
            Err(e) => {
                tracing::warn!("FTS5 keyword search failed: {e:#}");
                Vec::new()
            }
        };

        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            let query_vector = self.provider.embed(query).await?;
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant.search(&query_vector, limit * 2, filter).await?
        } else {
            Vec::new()
        };

        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }
523
    /// Raw FTS5 keyword search, over-fetching 2x `limit` so downstream fusion
    /// has candidates to re-rank.
    ///
    /// # Errors
    /// Propagates SQLite/FTS5 errors.
    async fn recall_fts5_raw(
        &self,
        query: &str,
        limit: usize,
        conversation_id: Option<ConversationId>,
    ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
        self.sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
    }
539
    /// Raw vector search, over-fetching 2x `limit`. Returns empty when no
    /// vector store is attached or the provider lacks embedding support.
    ///
    /// # Errors
    /// Fails if embedding, collection setup, or the search itself fails.
    async fn recall_vectors_raw(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }
        let query_vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
        qdrant.ensure_collection(vector_size).await?;
        qdrant.search(&query_vector, limit * 2, filter).await
    }
563
    /// Fuses keyword and vector scores, optionally applies temporal decay and
    /// MMR diversification, then hydrates the winning messages from SQLite.
    ///
    /// Each source's scores are max-normalized before being weighted by
    /// `vector_weight` / `keyword_weight`; messages found by both sources sum
    /// both contributions.
    ///
    /// # Errors
    /// Fails if the final message rows cannot be loaded; decay/MMR data
    /// fetch failures are logged and degrade gracefully.
    #[allow(clippy::cast_possible_truncation)]
    async fn recall_merge_and_rank(
        &self,
        keyword_results: Vec<(MessageId, f64)>,
        vector_results: Vec<crate::embedding_store::SearchResult>,
        limit: usize,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        // Max-normalize vector scores into [0, 1] and weight them in.
        if !vector_results.is_empty() {
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        // Same normalization for keyword scores.
        if !keyword_results.is_empty() {
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            return Ok(Vec::new());
        }

        // Sort descending by fused score.
        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Optional temporal decay: down-weight old messages, then re-sort.
        // Best-effort — a timestamp-fetch failure just skips decay.
        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                }
                Err(e) => {
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // Optional MMR diversification (only meaningful when vectors exist);
        // every fallback path simply truncates to `limit`.
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                    }
                    Ok(_) => {
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Hydrate final ids from SQLite, preserving rank order; ids whose rows
        // vanished are silently dropped.
        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        Ok(recalled)
    }
670
    /// Recall with a routing decision: the router picks keyword-only,
    /// semantic-only, or hybrid retrieval for the query, and the chosen raw
    /// results are fused by `recall_merge_and_rank`.
    ///
    /// # Errors
    /// Fails if the selected vector search fails; in the hybrid route a
    /// keyword-search failure is logged and treated as empty.
    pub async fn recall_routed(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
        router: &dyn crate::router::MemoryRouter,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        use crate::router::MemoryRoute;

        let route = router.route(query);
        tracing::debug!(?route, query_len = query.len(), "memory routing decision");

        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let (keyword_results, vector_results): (
            Vec<(MessageId, f64)>,
            Vec<crate::embedding_store::SearchResult>,
        ) = match route {
            MemoryRoute::Keyword => {
                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
                (kw, Vec::new())
            }
            MemoryRoute::Semantic => {
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (Vec::new(), vr)
            }
            MemoryRoute::Hybrid => {
                // Keyword side is best-effort; vector side is authoritative.
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed: {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
        };

        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }
724
725 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
731 match &self.qdrant {
732 Some(qdrant) => qdrant.has_embedding(message_id).await,
733 None => Ok(false),
734 }
735 }
736
    /// Backfills embeddings for messages saved without one (up to 1000 per
    /// call). Returns the number of embeddings stored; 0 when no vector store
    /// or no embedding support.
    ///
    /// # Errors
    /// Fails if the probe embedding or collection setup fails; per-message
    /// embed/store failures are logged and skipped.
    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(0);
        };
        if !self.provider.supports_embeddings() {
            return Ok(0);
        }

        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;

        if unembedded.is_empty() {
            return Ok(0);
        }

        // A probe embedding establishes the vector dimensionality up front.
        let probe = self.provider.embed("probe").await?;
        let vector_size = u64::try_from(probe.len())?;
        qdrant.ensure_collection(vector_size).await?;

        let mut count = 0;
        for (msg_id, conversation_id, role, content) in &unembedded {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    if let Err(e) = qdrant
                        .store(
                            *msg_id,
                            *conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
                        continue;
                    }
                    count += 1;
                }
                Err(e) => {
                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
                }
            }
        }

        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
        Ok(count)
    }
792
    /// Embeds and stores a session summary in the dedicated summaries
    /// collection. No-op when no vector store or no embedding support.
    ///
    /// # Errors
    /// Fails if embedding, collection setup, or the store call fails.
    pub async fn store_session_summary(
        &self,
        conversation_id: ConversationId,
        summary_text: &str,
    ) -> Result<(), MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(());
        };
        if !self.provider.supports_embeddings() {
            return Ok(());
        }

        let vector = self.provider.embed(summary_text).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        // Payload fields mirrored back by `search_session_summaries`.
        let payload = serde_json::json!({
            "conversation_id": conversation_id.0,
            "summary_text": summary_text,
        });

        qdrant
            .store_to_collection(SESSION_SUMMARIES_COLLECTION, payload, vector)
            .await?;

        tracing::debug!(
            conversation_id = conversation_id.0,
            "stored session summary"
        );
        Ok(())
    }
831
    /// Semantic search over stored session summaries, optionally excluding one
    /// conversation (typically the current one). Points with missing/invalid
    /// payload fields are silently skipped.
    ///
    /// # Errors
    /// Fails if embedding, collection setup, or the search fails.
    pub async fn search_session_summaries(
        &self,
        query: &str,
        limit: usize,
        exclude_conversation_id: Option<ConversationId>,
    ) -> Result<Vec<SessionSummaryResult>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
            .await?;

        // must_not filter drops summaries from the excluded conversation.
        let filter = exclude_conversation_id.map(|cid| VectorFilter {
            must: vec![],
            must_not: vec![FieldCondition {
                field: "conversation_id".into(),
                value: FieldValue::Integer(cid.0),
            }],
        });

        let points = qdrant
            .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter)
            .await?;

        let results = points
            .into_iter()
            .filter_map(|point| {
                let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned();
                let conversation_id =
                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
                Some(SessionSummaryResult {
                    summary_text,
                    score: point.score,
                    conversation_id,
                })
            })
            .collect();

        Ok(results)
    }
884
    /// Direct access to the underlying SQLite store.
    #[must_use]
    pub fn sqlite(&self) -> &SqliteStore {
        &self.sqlite
    }
890
891 pub async fn is_vector_store_connected(&self) -> bool {
896 match self.qdrant.as_ref() {
897 Some(store) => store.health_check().await,
898 None => false,
899 }
900 }
901
    /// Whether a vector store is attached at all (no reachability check —
    /// see `is_vector_store_connected` for that).
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }
907
    /// Total number of messages stored for `conversation_id`.
    ///
    /// # Errors
    /// Propagates SQLite errors.
    pub async fn message_count(&self, conversation_id: ConversationId) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }
916
917 pub async fn unsummarized_message_count(
923 &self,
924 conversation_id: ConversationId,
925 ) -> Result<i64, MemoryError> {
926 let after_id = self
927 .sqlite
928 .latest_summary_last_message_id(conversation_id)
929 .await?
930 .unwrap_or(MessageId(0));
931 self.sqlite
932 .count_messages_after(conversation_id, after_id)
933 .await
934 }
935
936 pub async fn load_summaries(
942 &self,
943 conversation_id: ConversationId,
944 ) -> Result<Vec<Summary>, MemoryError> {
945 let rows = self.sqlite.load_summaries(conversation_id).await?;
946 let summaries = rows
947 .into_iter()
948 .map(
949 |(
950 id,
951 conversation_id,
952 content,
953 first_message_id,
954 last_message_id,
955 token_estimate,
956 )| {
957 Summary {
958 id,
959 conversation_id,
960 content,
961 first_message_id,
962 last_message_id,
963 token_estimate,
964 }
965 },
966 )
967 .collect();
968 Ok(summaries)
969 }
970
    /// Summarizes the oldest unsummarized messages once the conversation has
    /// more than `message_count` messages in total.
    ///
    /// Returns `Ok(None)` when below the threshold or nothing to summarize;
    /// otherwise the new summary row id. Prefers structured LLM output and
    /// falls back to plain chat text. Summary embedding and key-fact storage
    /// are best-effort (logged, not propagated).
    ///
    /// # Errors
    /// Fails on SQLite errors, LLM chat failure (in the fallback path), or
    /// integer conversion of counts.
    pub async fn summarize(
        &self,
        conversation_id: ConversationId,
        message_count: usize,
    ) -> Result<Option<i64>, MemoryError> {
        let total = self.sqlite.count_messages(conversation_id).await?;

        // Only summarize once the conversation has grown past the threshold.
        if total <= i64::try_from(message_count)? {
            return Ok(None);
        }

        // Resume after the last summarized message; 0 = nothing summarized yet.
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));

        let messages = self
            .sqlite
            .load_messages_range(conversation_id, after_id, message_count)
            .await?;

        if messages.is_empty() {
            return Ok(None);
        }

        let prompt = build_summarization_prompt(&messages);
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        // Prefer structured (typed JSON) output; degrade to plain chat text.
        let structured = match self
            .provider
            .chat_typed_erased::<StructuredSummary>(&chat_messages)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!(
                    "structured summarization failed, falling back to plain text: {e:#}"
                );
                let plain = self.provider.chat(&chat_messages).await?;
                StructuredSummary {
                    summary: plain,
                    key_facts: vec![],
                    entities: vec![],
                }
            }
        };
        let summary_text = &structured.summary;

        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
        let first_message_id = messages[0].0;
        let last_message_id = messages[messages.len() - 1].0;

        let summary_id = self
            .sqlite
            .save_summary(
                conversation_id,
                summary_text,
                first_message_id,
                last_message_id,
                token_estimate,
            )
            .await?;

        // Best-effort: embed the summary so semantic recall can surface it.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(summary_text).await {
                Ok(vector) => {
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            MessageId(summary_id),
                            conversation_id,
                            "system",
                            vector,
                            MessageKind::Summary,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to embed summary: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate summary embedding: {e:#}");
                }
            }
        }

        // Best-effort: persist extracted key facts (errors logged inside).
        if !structured.key_facts.is_empty() {
            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
                .await;
        }

        Ok(Some(summary_id))
    }
1084
1085 async fn store_key_facts(
1086 &self,
1087 conversation_id: ConversationId,
1088 source_summary_id: i64,
1089 key_facts: &[String],
1090 ) {
1091 let Some(qdrant) = &self.qdrant else {
1092 return;
1093 };
1094 if !self.provider.supports_embeddings() {
1095 return;
1096 }
1097
1098 let Some(first_fact) = key_facts.first() else {
1099 return;
1100 };
1101 let first_vector = match self.provider.embed(first_fact).await {
1102 Ok(v) => v,
1103 Err(e) => {
1104 tracing::warn!("Failed to embed key fact: {e:#}");
1105 return;
1106 }
1107 };
1108 let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
1109 if let Err(e) = qdrant
1110 .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
1111 .await
1112 {
1113 tracing::warn!("Failed to ensure key_facts collection: {e:#}");
1114 return;
1115 }
1116
1117 let first_payload = serde_json::json!({
1118 "conversation_id": conversation_id.0,
1119 "fact_text": first_fact,
1120 "source_summary_id": source_summary_id,
1121 });
1122 if let Err(e) = qdrant
1123 .store_to_collection(KEY_FACTS_COLLECTION, first_payload, first_vector)
1124 .await
1125 {
1126 tracing::warn!("Failed to store key fact: {e:#}");
1127 }
1128
1129 for fact in &key_facts[1..] {
1130 match self.provider.embed(fact).await {
1131 Ok(vector) => {
1132 let payload = serde_json::json!({
1133 "conversation_id": conversation_id.0,
1134 "fact_text": fact,
1135 "source_summary_id": source_summary_id,
1136 });
1137 if let Err(e) = qdrant
1138 .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
1139 .await
1140 {
1141 tracing::warn!("Failed to store key fact: {e:#}");
1142 }
1143 }
1144 Err(e) => {
1145 tracing::warn!("Failed to embed key fact: {e:#}");
1146 }
1147 }
1148 }
1149 }
1150
    /// Semantic search over stored key facts; returns the matching fact texts.
    /// Empty when no vector store or no embedding support.
    ///
    /// # Errors
    /// Fails if embedding, collection setup, or the search fails.
    pub async fn search_key_facts(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<String>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        qdrant
            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
            .await?;

        let points = qdrant
            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
            .await?;

        // Points with a missing/non-string "fact_text" payload are skipped.
        let facts = points
            .into_iter()
            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
            .collect();

        Ok(facts)
    }
1185
    /// Semantic search over an arbitrary named document collection. Returns
    /// empty when no vector store, no embedding support, or the collection
    /// does not exist (it is never created here).
    ///
    /// # Errors
    /// Fails if the existence check, embedding, or search fails.
    pub async fn search_document_collection(
        &self,
        collection: &str,
        query: &str,
        limit: usize,
    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }
        if !qdrant.collection_exists(collection).await? {
            return Ok(Vec::new());
        }
        let vector = self.provider.embed(query).await?;
        qdrant
            .search_collection(collection, &vector, limit, None)
            .await
    }
1215
1216 pub async fn store_correction_embedding(
1224 &self,
1225 correction_id: i64,
1226 correction_text: &str,
1227 ) -> Result<(), MemoryError> {
1228 let Some(ref store) = self.qdrant else {
1229 return Ok(());
1230 };
1231 if !self.provider.supports_embeddings() {
1232 return Ok(());
1233 }
1234 let embedding = self
1235 .provider
1236 .embed(correction_text)
1237 .await
1238 .map_err(|e| MemoryError::Other(e.to_string()))?;
1239 let payload = serde_json::json!({ "correction_id": correction_id });
1240 store
1241 .store_to_collection(CORRECTIONS_COLLECTION, payload, embedding)
1242 .await?;
1243 Ok(())
1244 }
1245
    /// Finds stored corrections semantically similar to `query` (score >=
    /// `min_score`) and loads their full rows from SQLite.
    ///
    /// The vector search itself is best-effort: a search failure yields an
    /// empty candidate set rather than an error.
    ///
    /// # Errors
    /// Fails if embedding fails or loading a correction row from SQLite fails.
    pub async fn retrieve_similar_corrections(
        &self,
        query: &str,
        limit: usize,
        min_score: f32,
    ) -> Result<Vec<crate::sqlite::corrections::UserCorrectionRow>, MemoryError> {
        let Some(ref store) = self.qdrant else {
            return Ok(vec![]);
        };
        if !self.provider.supports_embeddings() {
            return Ok(vec![]);
        }
        let embedding = self
            .provider
            .embed(query)
            .await
            .map_err(|e| MemoryError::Other(e.to_string()))?;
        // NOTE: search errors are swallowed (treated as no matches).
        let scored = store
            .search_collection(CORRECTIONS_COLLECTION, &embedding, limit, None)
            .await
            .unwrap_or_default();

        let mut results = Vec::new();
        for point in scored {
            if point.score < min_score {
                continue;
            }
            // Points without a valid integer "correction_id" are skipped.
            if let Some(id_val) = point.payload.get("correction_id")
                && let Some(id) = id_val.as_i64()
            {
                let rows = self.sqlite.load_corrections_for_id(id).await?;
                results.extend(rows);
            }
        }
        Ok(results)
    }
1290}
1291
1292#[cfg(test)]
1293mod tests {
1294 use zeph_llm::mock::MockProvider;
1295 use zeph_llm::provider::Role;
1296
1297 use super::*;
1298
    /// Builds a mock LLM provider for tests.
    fn test_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }
1302
    /// Builds an in-memory `SemanticMemory` with no vector store attached.
    ///
    /// NOTE(review): `_supports_embeddings` is currently unused — the mock
    /// provider's embedding support is not toggled here; confirm intent.
    async fn test_semantic_memory(_supports_embeddings: bool) -> SemanticMemory {
        let provider = test_provider();
        let sqlite = SqliteStore::new(":memory:").await.unwrap();

        SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        }
    }
1321
    // remember() must persist to SQLite even without a vector store.
    #[tokio::test]
    async fn remember_saves_to_sqlite() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite.create_conversation().await.unwrap();
        let msg_id = memory.remember(cid, "user", "hello").await.unwrap();

        assert_eq!(msg_id, MessageId(1));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
    }
1336
    // remember_with_parts() must persist the message alongside its parts JSON.
    #[tokio::test]
    async fn remember_with_parts_saves_parts_json() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let parts_json =
            r#"[{"kind":"ToolOutput","tool_name":"shell","body":"hello","compacted_at":null}]"#;
        let (msg_id, _embedding_stored) = memory
            .remember_with_parts(cid, "assistant", "tool output", parts_json)
            .await
            .unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "tool output");
    }
1354
    // recall() degrades to empty results when no vector store is attached.
    #[tokio::test]
    async fn recall_returns_empty_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }
1362
    // has_embedding() reports false when no vector store is attached.
    #[tokio::test]
    async fn has_embedding_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let has_embedding = memory.has_embedding(MessageId(1)).await.unwrap();
        assert!(!has_embedding);
    }
1370
    // embed_missing() is a no-op when no vector store is attached.
    #[tokio::test]
    async fn embed_missing_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }
1378
    // The sqlite() accessor exposes the working underlying store.
    #[tokio::test]
    async fn sqlite_accessor() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        assert_eq!(cid, ConversationId(1));

        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
    }
1395
1396 #[tokio::test]
1397 async fn has_vector_store_returns_false_when_unavailable() {
1398 let memory = test_semantic_memory(false).await;
1399 assert!(!memory.has_vector_store());
1400 }
1401
1402 #[tokio::test]
1403 async fn is_vector_store_connected_returns_false_when_unavailable() {
1404 let memory = test_semantic_memory(false).await;
1405 assert!(!memory.is_vector_store_connected().await);
1406 }
1407
1408 #[tokio::test]
1409 async fn recall_returns_empty_when_embeddings_not_supported() {
1410 let memory = test_semantic_memory(false).await;
1411
1412 let recalled = memory.recall("test", 5, None).await.unwrap();
1413 assert!(recalled.is_empty());
1414 }
1415
1416 #[tokio::test]
1417 async fn embed_missing_returns_zero_when_embeddings_not_supported() {
1418 let memory = test_semantic_memory(false).await;
1419
1420 let cid = memory.sqlite().create_conversation().await.unwrap();
1421 memory
1422 .sqlite()
1423 .save_message(cid, "user", "test")
1424 .await
1425 .unwrap();
1426
1427 let count = memory.embed_missing().await.unwrap();
1428 assert_eq!(count, 0);
1429 }
1430
1431 #[tokio::test]
1432 async fn message_count_empty_conversation() {
1433 let memory = test_semantic_memory(false).await;
1434 let cid = memory.sqlite().create_conversation().await.unwrap();
1435
1436 let count = memory.message_count(cid).await.unwrap();
1437 assert_eq!(count, 0);
1438 }
1439
1440 #[tokio::test]
1441 async fn message_count_after_saves() {
1442 let memory = test_semantic_memory(false).await;
1443 let cid = memory.sqlite().create_conversation().await.unwrap();
1444
1445 memory.remember(cid, "user", "msg1").await.unwrap();
1446 memory.remember(cid, "assistant", "msg2").await.unwrap();
1447
1448 let count = memory.message_count(cid).await.unwrap();
1449 assert_eq!(count, 2);
1450 }
1451
1452 #[tokio::test]
1453 async fn unsummarized_count_decreases_after_summary() {
1454 let memory = test_semantic_memory(false).await;
1455 let cid = memory.sqlite().create_conversation().await.unwrap();
1456
1457 for i in 0..10 {
1458 memory
1459 .remember(cid, "user", &format!("msg{i}"))
1460 .await
1461 .unwrap();
1462 }
1463 assert_eq!(memory.unsummarized_message_count(cid).await.unwrap(), 10);
1464
1465 memory.summarize(cid, 5).await.unwrap();
1466
1467 assert!(memory.unsummarized_message_count(cid).await.unwrap() < 10);
1468 assert_eq!(memory.message_count(cid).await.unwrap(), 10);
1469 }
1470
1471 #[tokio::test]
1472 async fn load_summaries_empty() {
1473 let memory = test_semantic_memory(false).await;
1474 let cid = memory.sqlite().create_conversation().await.unwrap();
1475
1476 let summaries = memory.load_summaries(cid).await.unwrap();
1477 assert!(summaries.is_empty());
1478 }
1479
1480 #[tokio::test]
1481 async fn load_summaries_ordered() {
1482 let memory = test_semantic_memory(false).await;
1483 let cid = memory.sqlite().create_conversation().await.unwrap();
1484
1485 let msg_id1 = memory.remember(cid, "user", "m1").await.unwrap();
1486 let msg_id2 = memory.remember(cid, "assistant", "m2").await.unwrap();
1487 let msg_id3 = memory.remember(cid, "user", "m3").await.unwrap();
1488
1489 let s1 = memory
1490 .sqlite()
1491 .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
1492 .await
1493 .unwrap();
1494 let s2 = memory
1495 .sqlite()
1496 .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
1497 .await
1498 .unwrap();
1499
1500 let summaries = memory.load_summaries(cid).await.unwrap();
1501 assert_eq!(summaries.len(), 2);
1502 assert_eq!(summaries[0].id, s1);
1503 assert_eq!(summaries[0].content, "summary1");
1504 assert_eq!(summaries[1].id, s2);
1505 assert_eq!(summaries[1].content, "summary2");
1506 }
1507
1508 #[tokio::test]
1509 async fn summarize_below_threshold() {
1510 let memory = test_semantic_memory(false).await;
1511 let cid = memory.sqlite().create_conversation().await.unwrap();
1512
1513 memory.remember(cid, "user", "hello").await.unwrap();
1514
1515 let result = memory.summarize(cid, 10).await.unwrap();
1516 assert!(result.is_none());
1517 }
1518
1519 #[tokio::test]
1520 async fn summarize_stores_summary() {
1521 let memory = test_semantic_memory(false).await;
1522 let cid = memory.sqlite().create_conversation().await.unwrap();
1523
1524 for i in 0..5 {
1525 memory
1526 .remember(cid, "user", &format!("message {i}"))
1527 .await
1528 .unwrap();
1529 }
1530
1531 let summary_id = memory.summarize(cid, 3).await.unwrap();
1532 assert!(summary_id.is_some());
1533
1534 let summaries = memory.load_summaries(cid).await.unwrap();
1535 assert_eq!(summaries.len(), 1);
1536 assert_eq!(summaries[0].id, summary_id.unwrap());
1537 assert!(!summaries[0].content.is_empty());
1538 }
1539
1540 #[tokio::test]
1541 async fn summarize_respects_previous_summaries() {
1542 let memory = test_semantic_memory(false).await;
1543 let cid = memory.sqlite().create_conversation().await.unwrap();
1544
1545 for i in 0..10 {
1546 memory
1547 .remember(cid, "user", &format!("message {i}"))
1548 .await
1549 .unwrap();
1550 }
1551
1552 let s1 = memory.summarize(cid, 3).await.unwrap();
1553 assert!(s1.is_some());
1554
1555 let s2 = memory.summarize(cid, 3).await.unwrap();
1556 assert!(s2.is_some());
1557
1558 let summaries = memory.load_summaries(cid).await.unwrap();
1559 assert_eq!(summaries.len(), 2);
1560 assert!(summaries[0].last_message_id < summaries[1].first_message_id);
1561 }
1562
1563 #[tokio::test]
1564 async fn remember_multiple_messages_increments_ids() {
1565 let memory = test_semantic_memory(false).await;
1566 let cid = memory.sqlite.create_conversation().await.unwrap();
1567
1568 let id1 = memory.remember(cid, "user", "first").await.unwrap();
1569 let id2 = memory.remember(cid, "assistant", "second").await.unwrap();
1570 let id3 = memory.remember(cid, "user", "third").await.unwrap();
1571
1572 assert!(id1 < id2);
1573 assert!(id2 < id3);
1574 }
1575
1576 #[tokio::test]
1577 async fn message_count_across_conversations() {
1578 let memory = test_semantic_memory(false).await;
1579 let cid1 = memory.sqlite().create_conversation().await.unwrap();
1580 let cid2 = memory.sqlite().create_conversation().await.unwrap();
1581
1582 memory.remember(cid1, "user", "msg1").await.unwrap();
1583 memory.remember(cid1, "user", "msg2").await.unwrap();
1584 memory.remember(cid2, "user", "msg3").await.unwrap();
1585
1586 assert_eq!(memory.message_count(cid1).await.unwrap(), 2);
1587 assert_eq!(memory.message_count(cid2).await.unwrap(), 1);
1588 }
1589
1590 #[tokio::test]
1591 async fn summarize_exact_threshold_returns_none() {
1592 let memory = test_semantic_memory(false).await;
1593 let cid = memory.sqlite().create_conversation().await.unwrap();
1594
1595 for i in 0..3 {
1596 memory
1597 .remember(cid, "user", &format!("msg {i}"))
1598 .await
1599 .unwrap();
1600 }
1601
1602 let result = memory.summarize(cid, 3).await.unwrap();
1603 assert!(result.is_none());
1604 }
1605
1606 #[tokio::test]
1607 async fn summarize_one_above_threshold_produces_summary() {
1608 let memory = test_semantic_memory(false).await;
1609 let cid = memory.sqlite().create_conversation().await.unwrap();
1610
1611 for i in 0..4 {
1612 memory
1613 .remember(cid, "user", &format!("msg {i}"))
1614 .await
1615 .unwrap();
1616 }
1617
1618 let result = memory.summarize(cid, 3).await.unwrap();
1619 assert!(result.is_some());
1620 }
1621
1622 #[tokio::test]
1623 async fn summary_fields_populated() {
1624 let memory = test_semantic_memory(false).await;
1625 let cid = memory.sqlite().create_conversation().await.unwrap();
1626
1627 for i in 0..5 {
1628 memory
1629 .remember(cid, "user", &format!("msg {i}"))
1630 .await
1631 .unwrap();
1632 }
1633
1634 memory.summarize(cid, 3).await.unwrap();
1635 let summaries = memory.load_summaries(cid).await.unwrap();
1636 let s = &summaries[0];
1637
1638 assert_eq!(s.conversation_id, cid);
1639 assert!(s.first_message_id > MessageId(0));
1640 assert!(s.last_message_id >= s.first_message_id);
1641 assert!(s.token_estimate >= 0);
1642 assert!(!s.content.is_empty());
1643 }
1644
1645 #[test]
1646 fn build_summarization_prompt_format() {
1647 let messages = vec![
1648 (MessageId(1), "user".into(), "Hello".into()),
1649 (MessageId(2), "assistant".into(), "Hi there".into()),
1650 ];
1651 let prompt = build_summarization_prompt(&messages);
1652 assert!(prompt.contains("user: Hello"));
1653 assert!(prompt.contains("assistant: Hi there"));
1654 assert!(prompt.contains("key_facts"));
1655 }
1656
1657 #[test]
1658 fn build_summarization_prompt_empty() {
1659 let messages: Vec<(MessageId, String, String)> = vec![];
1660 let prompt = build_summarization_prompt(&messages);
1661 assert!(prompt.contains("key_facts"));
1662 }
1663
1664 #[test]
1665 fn structured_summary_deserialize() {
1666 let json = r#"{"summary":"s","key_facts":["f1","f2"],"entities":["e1"]}"#;
1667 let ss: StructuredSummary = serde_json::from_str(json).unwrap();
1668 assert_eq!(ss.summary, "s");
1669 assert_eq!(ss.key_facts.len(), 2);
1670 assert_eq!(ss.entities.len(), 1);
1671 }
1672
1673 #[test]
1674 fn structured_summary_empty_facts() {
1675 let json = r#"{"summary":"s","key_facts":[],"entities":[]}"#;
1676 let ss: StructuredSummary = serde_json::from_str(json).unwrap();
1677 assert!(ss.key_facts.is_empty());
1678 assert!(ss.entities.is_empty());
1679 }
1680
1681 #[tokio::test]
1682 async fn search_key_facts_no_qdrant_empty() {
1683 let memory = test_semantic_memory(false).await;
1684 let facts = memory.search_key_facts("query", 5).await.unwrap();
1685 assert!(facts.is_empty());
1686 }
1687
1688 #[test]
1689 fn recalled_message_debug() {
1690 let recalled = RecalledMessage {
1691 message: Message {
1692 role: Role::User,
1693 content: "test".into(),
1694 parts: vec![],
1695 metadata: MessageMetadata::default(),
1696 },
1697 score: 0.95,
1698 };
1699 let dbg = format!("{recalled:?}");
1700 assert!(dbg.contains("RecalledMessage"));
1701 assert!(dbg.contains("0.95"));
1702 }
1703
1704 #[test]
1705 fn summary_clone() {
1706 let summary = Summary {
1707 id: 1,
1708 conversation_id: ConversationId(2),
1709 content: "test summary".into(),
1710 first_message_id: MessageId(1),
1711 last_message_id: MessageId(5),
1712 token_estimate: 10,
1713 };
1714 let cloned = summary.clone();
1715 assert_eq!(summary.id, cloned.id);
1716 assert_eq!(summary.content, cloned.content);
1717 }
1718
1719 #[tokio::test]
1720 async fn remember_preserves_role_mapping() {
1721 let memory = test_semantic_memory(false).await;
1722 let cid = memory.sqlite.create_conversation().await.unwrap();
1723
1724 memory.remember(cid, "user", "u").await.unwrap();
1725 memory.remember(cid, "assistant", "a").await.unwrap();
1726 memory.remember(cid, "system", "s").await.unwrap();
1727
1728 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1729 assert_eq!(history.len(), 3);
1730 assert_eq!(history[0].role, Role::User);
1731 assert_eq!(history[1].role, Role::Assistant);
1732 assert_eq!(history[2].role, Role::System);
1733 }
1734
1735 #[tokio::test]
1736 async fn new_with_invalid_qdrant_url_graceful() {
1737 let mut mock = MockProvider::default();
1738 mock.supports_embeddings = true;
1739 let provider = AnyProvider::Mock(mock);
1740 let result =
1741 SemanticMemory::new(":memory:", "http://127.0.0.1:1", provider, "test-model").await;
1742 assert!(result.is_ok());
1743 }
1744
1745 #[tokio::test]
1746 async fn test_semantic_memory_sqlite_remember_recall_roundtrip() {
1747 let mut mock = MockProvider::default();
1749 mock.supports_embeddings = true;
1750 let provider = AnyProvider::Mock(mock);
1753
1754 let sqlite = SqliteStore::new(":memory:").await.unwrap();
1755 let pool = sqlite.pool().clone();
1756 let qdrant = Some(crate::embedding_store::EmbeddingStore::new_sqlite(pool));
1757
1758 let memory = SemanticMemory {
1759 sqlite,
1760 qdrant,
1761 provider,
1762 embedding_model: "test-model".into(),
1763 vector_weight: 0.7,
1764 keyword_weight: 0.3,
1765 temporal_decay_enabled: false,
1766 temporal_decay_half_life_days: 30,
1767 mmr_enabled: false,
1768 mmr_lambda: 0.7,
1769 token_counter: Arc::new(TokenCounter::new()),
1770 };
1771
1772 let cid = memory.sqlite().create_conversation().await.unwrap();
1773
1774 let id1 = memory
1776 .remember(cid, "user", "rust async programming")
1777 .await
1778 .unwrap();
1779 let id2 = memory
1780 .remember(cid, "assistant", "use tokio for async")
1781 .await
1782 .unwrap();
1783 assert!(id1 < id2);
1784
1785 let recalled = memory.recall("rust", 5, None).await.unwrap();
1787 assert!(
1788 !recalled.is_empty(),
1789 "recall must return at least one result"
1790 );
1791
1792 let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1794 assert_eq!(history.len(), 2);
1795 assert_eq!(history[0].content, "rust async programming");
1796 }
1797
1798 #[tokio::test]
1799 async fn remember_with_embeddings_supported_but_no_qdrant() {
1800 let memory = test_semantic_memory(true).await;
1801 let cid = memory.sqlite.create_conversation().await.unwrap();
1802
1803 let msg_id = memory.remember(cid, "user", "hello embed").await.unwrap();
1804 assert!(msg_id > MessageId(0));
1805
1806 let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1807 assert_eq!(history.len(), 1);
1808 assert_eq!(history[0].content, "hello embed");
1809 }
1810
1811 #[tokio::test]
1812 async fn remember_verifies_content_via_load_history() {
1813 let memory = test_semantic_memory(false).await;
1814 let cid = memory.sqlite.create_conversation().await.unwrap();
1815
1816 memory.remember(cid, "user", "alpha").await.unwrap();
1817 memory.remember(cid, "assistant", "beta").await.unwrap();
1818 memory.remember(cid, "user", "gamma").await.unwrap();
1819
1820 let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1821 assert_eq!(history.len(), 3);
1822 assert_eq!(history[0].content, "alpha");
1823 assert_eq!(history[1].content, "beta");
1824 assert_eq!(history[2].content, "gamma");
1825 }
1826
1827 #[tokio::test]
1828 async fn message_count_multiple_conversations_isolated() {
1829 let memory = test_semantic_memory(false).await;
1830 let cid1 = memory.sqlite().create_conversation().await.unwrap();
1831 let cid2 = memory.sqlite().create_conversation().await.unwrap();
1832 let cid3 = memory.sqlite().create_conversation().await.unwrap();
1833
1834 for _ in 0..5 {
1835 memory.remember(cid1, "user", "msg").await.unwrap();
1836 }
1837 for _ in 0..3 {
1838 memory.remember(cid2, "user", "msg").await.unwrap();
1839 }
1840
1841 assert_eq!(memory.message_count(cid1).await.unwrap(), 5);
1842 assert_eq!(memory.message_count(cid2).await.unwrap(), 3);
1843 assert_eq!(memory.message_count(cid3).await.unwrap(), 0);
1844 }
1845
1846 #[tokio::test]
1847 async fn summarize_empty_messages_range_returns_none() {
1848 let memory = test_semantic_memory(false).await;
1849 let cid = memory.sqlite().create_conversation().await.unwrap();
1850
1851 for i in 0..6 {
1852 memory
1853 .remember(cid, "user", &format!("msg {i}"))
1854 .await
1855 .unwrap();
1856 }
1857
1858 memory.summarize(cid, 3).await.unwrap();
1859 memory.summarize(cid, 3).await.unwrap();
1860
1861 let summaries = memory.load_summaries(cid).await.unwrap();
1862 assert_eq!(summaries.len(), 2);
1863 }
1864
1865 #[tokio::test]
1866 async fn summarize_token_estimate_populated() {
1867 let memory = test_semantic_memory(false).await;
1868 let cid = memory.sqlite().create_conversation().await.unwrap();
1869
1870 for i in 0..5 {
1871 memory
1872 .remember(cid, "user", &format!("message {i}"))
1873 .await
1874 .unwrap();
1875 }
1876
1877 memory.summarize(cid, 3).await.unwrap();
1878 let summaries = memory.load_summaries(cid).await.unwrap();
1879 let token_est = summaries[0].token_estimate;
1880 assert!(token_est > 0);
1881 }
1882
1883 #[tokio::test]
1884 async fn summarize_fails_when_provider_chat_fails() {
1885 let sqlite = SqliteStore::new(":memory:").await.unwrap();
1886 let provider = AnyProvider::Ollama(zeph_llm::ollama::OllamaProvider::new(
1887 "http://127.0.0.1:1",
1888 "test".into(),
1889 "embed".into(),
1890 ));
1891 let memory = SemanticMemory {
1892 sqlite,
1893 qdrant: None,
1894 provider,
1895 embedding_model: "test".into(),
1896 vector_weight: 0.7,
1897 keyword_weight: 0.3,
1898 temporal_decay_enabled: false,
1899 temporal_decay_half_life_days: 30,
1900 mmr_enabled: false,
1901 mmr_lambda: 0.7,
1902 token_counter: Arc::new(TokenCounter::new()),
1903 };
1904 let cid = memory.sqlite().create_conversation().await.unwrap();
1905
1906 for i in 0..5 {
1907 memory
1908 .remember(cid, "user", &format!("msg {i}"))
1909 .await
1910 .unwrap();
1911 }
1912
1913 let result = memory.summarize(cid, 3).await;
1914 assert!(result.is_err());
1915 }
1916
1917 #[tokio::test]
1918 async fn embed_missing_without_embedding_support_returns_zero() {
1919 let memory = test_semantic_memory(false).await;
1920 let cid = memory.sqlite().create_conversation().await.unwrap();
1921 memory
1922 .sqlite()
1923 .save_message(cid, "user", "test message")
1924 .await
1925 .unwrap();
1926
1927 let count = memory.embed_missing().await.unwrap();
1928 assert_eq!(count, 0);
1929 }
1930
1931 #[tokio::test]
1932 async fn has_embedding_returns_false_when_no_qdrant() {
1933 let memory = test_semantic_memory(false).await;
1934 let cid = memory.sqlite.create_conversation().await.unwrap();
1935 let msg_id = memory.remember(cid, "user", "test").await.unwrap();
1936 assert!(!memory.has_embedding(msg_id).await.unwrap());
1937 }
1938
1939 #[tokio::test]
1940 async fn recall_empty_without_qdrant_regardless_of_filter() {
1941 let memory = test_semantic_memory(true).await;
1942 let filter = SearchFilter {
1943 conversation_id: Some(ConversationId(1)),
1944 role: None,
1945 };
1946 let recalled = memory.recall("query", 10, Some(filter)).await.unwrap();
1947 assert!(recalled.is_empty());
1948 }
1949
1950 #[tokio::test]
1951 async fn summarize_message_range_bounds() {
1952 let memory = test_semantic_memory(false).await;
1953 let cid = memory.sqlite().create_conversation().await.unwrap();
1954
1955 for i in 0..8 {
1956 memory
1957 .remember(cid, "user", &format!("msg {i}"))
1958 .await
1959 .unwrap();
1960 }
1961
1962 let summary_id = memory.summarize(cid, 4).await.unwrap().unwrap();
1963 let summaries = memory.load_summaries(cid).await.unwrap();
1964 assert_eq!(summaries.len(), 1);
1965 assert_eq!(summaries[0].id, summary_id);
1966 assert!(summaries[0].first_message_id >= MessageId(1));
1967 assert!(summaries[0].last_message_id >= summaries[0].first_message_id);
1968 }
1969
1970 #[test]
1971 fn build_summarization_prompt_preserves_order() {
1972 let messages = vec![
1973 (MessageId(1), "user".into(), "first".into()),
1974 (MessageId(2), "assistant".into(), "second".into()),
1975 (MessageId(3), "user".into(), "third".into()),
1976 ];
1977 let prompt = build_summarization_prompt(&messages);
1978 let first_pos = prompt.find("user: first").unwrap();
1979 let second_pos = prompt.find("assistant: second").unwrap();
1980 let third_pos = prompt.find("user: third").unwrap();
1981 assert!(first_pos < second_pos);
1982 assert!(second_pos < third_pos);
1983 }
1984
1985 #[test]
1986 fn summary_debug() {
1987 let summary = Summary {
1988 id: 1,
1989 conversation_id: ConversationId(2),
1990 content: "test".into(),
1991 first_message_id: MessageId(1),
1992 last_message_id: MessageId(5),
1993 token_estimate: 10,
1994 };
1995 let dbg = format!("{summary:?}");
1996 assert!(dbg.contains("Summary"));
1997 }
1998
1999 #[tokio::test]
2000 async fn message_count_nonexistent_conversation() {
2001 let memory = test_semantic_memory(false).await;
2002 let count = memory.message_count(ConversationId(999)).await.unwrap();
2003 assert_eq!(count, 0);
2004 }
2005
2006 #[tokio::test]
2007 async fn load_summaries_nonexistent_conversation() {
2008 let memory = test_semantic_memory(false).await;
2009 let summaries = memory.load_summaries(ConversationId(999)).await.unwrap();
2010 assert!(summaries.is_empty());
2011 }
2012
2013 #[tokio::test]
2014 async fn store_session_summary_no_qdrant_noop() {
2015 let memory = test_semantic_memory(true).await;
2016 let result = memory
2017 .store_session_summary(ConversationId(1), "test summary")
2018 .await;
2019 assert!(result.is_ok());
2020 }
2021
2022 #[tokio::test]
2023 async fn store_session_summary_no_embeddings_noop() {
2024 let memory = test_semantic_memory(false).await;
2025 let result = memory
2026 .store_session_summary(ConversationId(1), "test summary")
2027 .await;
2028 assert!(result.is_ok());
2029 }
2030
2031 #[tokio::test]
2032 async fn search_session_summaries_no_qdrant_empty() {
2033 let memory = test_semantic_memory(true).await;
2034 let results = memory
2035 .search_session_summaries("query", 5, None)
2036 .await
2037 .unwrap();
2038 assert!(results.is_empty());
2039 }
2040
2041 #[tokio::test]
2042 async fn search_session_summaries_no_embeddings_empty() {
2043 let memory = test_semantic_memory(false).await;
2044 let results = memory
2045 .search_session_summaries("query", 5, Some(ConversationId(1)))
2046 .await
2047 .unwrap();
2048 assert!(results.is_empty());
2049 }
2050
2051 #[test]
2052 fn session_summary_result_debug() {
2053 let result = SessionSummaryResult {
2054 summary_text: "test".into(),
2055 score: 0.9,
2056 conversation_id: ConversationId(1),
2057 };
2058 let dbg = format!("{result:?}");
2059 assert!(dbg.contains("SessionSummaryResult"));
2060 }
2061
2062 #[test]
2063 fn session_summary_result_clone() {
2064 let result = SessionSummaryResult {
2065 summary_text: "test".into(),
2066 score: 0.9,
2067 conversation_id: ConversationId(1),
2068 };
2069 let cloned = result.clone();
2070 assert_eq!(result.summary_text, cloned.summary_text);
2071 assert_eq!(result.conversation_id, cloned.conversation_id);
2072 }
2073
2074 #[tokio::test]
2075 async fn recall_fts5_fallback_without_qdrant() {
2076 let memory = test_semantic_memory(false).await;
2077 let cid = memory.sqlite.create_conversation().await.unwrap();
2078
2079 memory
2080 .remember(cid, "user", "rust programming guide")
2081 .await
2082 .unwrap();
2083 memory
2084 .remember(cid, "assistant", "python tutorial")
2085 .await
2086 .unwrap();
2087 memory
2088 .remember(cid, "user", "advanced rust patterns")
2089 .await
2090 .unwrap();
2091
2092 let recalled = memory.recall("rust", 5, None).await.unwrap();
2093 assert_eq!(recalled.len(), 2);
2094 assert!(recalled[0].score >= recalled[1].score);
2095 }
2096
2097 #[tokio::test]
2098 async fn recall_fts5_fallback_with_filter() {
2099 let memory = test_semantic_memory(false).await;
2100 let cid1 = memory.sqlite.create_conversation().await.unwrap();
2101 let cid2 = memory.sqlite.create_conversation().await.unwrap();
2102
2103 memory.remember(cid1, "user", "hello world").await.unwrap();
2104 memory
2105 .remember(cid2, "user", "hello universe")
2106 .await
2107 .unwrap();
2108
2109 let filter = SearchFilter {
2110 conversation_id: Some(cid1),
2111 role: None,
2112 };
2113 let recalled = memory.recall("hello", 5, Some(filter)).await.unwrap();
2114 assert_eq!(recalled.len(), 1);
2115 }
2116
2117 #[tokio::test]
2118 async fn recall_fts5_no_matches_returns_empty() {
2119 let memory = test_semantic_memory(false).await;
2120 let cid = memory.sqlite.create_conversation().await.unwrap();
2121
2122 memory.remember(cid, "user", "hello world").await.unwrap();
2123
2124 let recalled = memory.recall("nonexistent", 5, None).await.unwrap();
2125 assert!(recalled.is_empty());
2126 }
2127
2128 #[tokio::test]
2129 async fn recall_fts5_respects_limit() {
2130 let memory = test_semantic_memory(false).await;
2131 let cid = memory.sqlite.create_conversation().await.unwrap();
2132
2133 for i in 0..10 {
2134 memory
2135 .remember(cid, "user", &format!("test message number {i}"))
2136 .await
2137 .unwrap();
2138 }
2139
2140 let recalled = memory.recall("test", 3, None).await.unwrap();
2141 assert_eq!(recalled.len(), 3);
2142 }
2143
2144 #[tokio::test]
2147 async fn summarize_fallback_to_plain_text_when_structured_fails() {
2148 let sqlite = SqliteStore::new(":memory:").await.unwrap();
2156 let mut mock = MockProvider::default();
2157 mock.default_response = "plain text summary".into();
2159 let provider = AnyProvider::Mock(mock);
2160
2161 let memory = SemanticMemory {
2162 sqlite,
2163 qdrant: None,
2164 provider,
2165 embedding_model: "test".into(),
2166 vector_weight: 0.7,
2167 keyword_weight: 0.3,
2168 temporal_decay_enabled: false,
2169 temporal_decay_half_life_days: 30,
2170 mmr_enabled: false,
2171 mmr_lambda: 0.7,
2172 token_counter: Arc::new(TokenCounter::new()),
2173 };
2174
2175 let cid = memory.sqlite().create_conversation().await.unwrap();
2176 for i in 0..5 {
2177 memory
2178 .remember(cid, "user", &format!("msg {i}"))
2179 .await
2180 .unwrap();
2181 }
2182
2183 let result = memory.summarize(cid, 3).await;
2184 assert!(result.is_ok());
2190 let summaries = memory.load_summaries(cid).await.unwrap();
2191 assert_eq!(summaries.len(), 1);
2192 assert!(!summaries[0].content.is_empty());
2193 }
2194
    /// Scores survive untouched when no timestamps are supplied for the
    /// ranked messages.
    ///
    /// NOTE(review): despite the name, decay is not actually disabled here —
    /// a 30-day half-life is passed; the scores stay unchanged because the
    /// timestamp map is empty. Presumably messages without a timestamp are
    /// skipped by `apply_temporal_decay` — confirm and consider renaming.
    #[test]
    fn temporal_decay_disabled_leaves_scores_unchanged() {
        let mut ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
        let timestamps = std::collections::HashMap::new();
        apply_temporal_decay(&mut ranked, &timestamps, 30);
        assert!((ranked[0].1 - 1.0).abs() < f64::EPSILON);
        assert!((ranked[1].1 - 0.5).abs() < f64::EPSILON);
    }
2205
2206 #[test]
2207 fn temporal_decay_zero_age_preserves_score() {
2208 let now = std::time::SystemTime::now()
2209 .duration_since(std::time::UNIX_EPOCH)
2210 .unwrap_or_default()
2211 .as_secs()
2212 .cast_signed();
2213 let mut ranked = vec![(MessageId(1), 1.0f64)];
2214 let mut timestamps = std::collections::HashMap::new();
2215 timestamps.insert(MessageId(1), now);
2216 apply_temporal_decay(&mut ranked, ×tamps, 30);
2217 assert!((ranked[0].1 - 1.0).abs() < 0.01);
2219 }
2220
2221 #[test]
2222 fn temporal_decay_half_life_halves_score() {
2223 let half_life = 30u32;
2225 let age_secs = i64::from(half_life) * 86400;
2226 let now = std::time::SystemTime::now()
2227 .duration_since(std::time::UNIX_EPOCH)
2228 .unwrap_or_default()
2229 .as_secs()
2230 .cast_signed();
2231 let ts = now - age_secs;
2232 let mut ranked = vec![(MessageId(1), 1.0f64)];
2233 let mut timestamps = std::collections::HashMap::new();
2234 timestamps.insert(MessageId(1), ts);
2235 apply_temporal_decay(&mut ranked, ×tamps, half_life);
2236 assert!(
2238 (ranked[0].1 - 0.5).abs() < 0.01,
2239 "score was {}",
2240 ranked[0].1
2241 );
2242 }
2243
2244 #[test]
2247 fn mmr_empty_input_returns_empty() {
2248 let ranked = vec![];
2249 let vectors = std::collections::HashMap::new();
2250 let result = apply_mmr(&ranked, &vectors, 0.7, 5);
2251 assert!(result.is_empty());
2252 }
2253
2254 #[test]
2255 fn mmr_returns_up_to_limit() {
2256 let ranked = vec![
2257 (MessageId(1), 1.0f64),
2258 (MessageId(2), 0.9f64),
2259 (MessageId(3), 0.8f64),
2260 ];
2261 let mut vectors = std::collections::HashMap::new();
2262 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2263 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2264 vectors.insert(MessageId(3), vec![1.0f32, 0.0]);
2265 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2266 assert_eq!(result.len(), 2);
2267 }
2268
2269 #[test]
2270 fn mmr_without_vectors_picks_by_relevance() {
2271 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2272 let vectors = std::collections::HashMap::new();
2273 let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2274 assert_eq!(result.len(), 2);
2275 assert_eq!(result[0].0, MessageId(1));
2276 }
2277
2278 #[test]
2279 fn mmr_prefers_diverse_over_redundant() {
2280 let ranked = vec![
2282 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.9f64), ];
2286 let mut vectors = std::collections::HashMap::new();
2287 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2288 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 2);
2291 assert_eq!(result.len(), 2);
2292 assert_eq!(result[0].0, MessageId(1));
2293 assert_eq!(result[1].0, MessageId(2));
2295 }
2296
2297 #[test]
2298 fn temporal_decay_half_life_zero_is_noop() {
2299 let now = std::time::SystemTime::now()
2300 .duration_since(std::time::UNIX_EPOCH)
2301 .unwrap_or_default()
2302 .as_secs()
2303 .cast_signed();
2304 let age_secs = 30i64 * 86400;
2305 let ts = now - age_secs;
2306 let mut ranked = vec![(MessageId(1), 1.0f64)];
2307 let mut timestamps = std::collections::HashMap::new();
2308 timestamps.insert(MessageId(1), ts);
2309 apply_temporal_decay(&mut ranked, ×tamps, 0);
2311 assert!(
2312 (ranked[0].1 - 1.0).abs() < f64::EPSILON,
2313 "score was {}",
2314 ranked[0].1
2315 );
2316 }
2317
2318 #[test]
2319 fn temporal_decay_huge_age_near_zero() {
2320 let now = std::time::SystemTime::now()
2321 .duration_since(std::time::UNIX_EPOCH)
2322 .unwrap_or_default()
2323 .as_secs()
2324 .cast_signed();
2325 let age_secs = 3650i64 * 86400;
2327 let ts = now - age_secs;
2328 let mut ranked = vec![(MessageId(1), 1.0f64)];
2329 let mut timestamps = std::collections::HashMap::new();
2330 timestamps.insert(MessageId(1), ts);
2331 apply_temporal_decay(&mut ranked, ×tamps, 30);
2332 assert!(ranked[0].1 < 0.001, "score was {}", ranked[0].1);
2334 }
2335
2336 #[test]
2337 fn temporal_decay_small_half_life() {
2338 let now = std::time::SystemTime::now()
2340 .duration_since(std::time::UNIX_EPOCH)
2341 .unwrap_or_default()
2342 .as_secs()
2343 .cast_signed();
2344 let ts = now - 7 * 86400i64;
2345 let mut ranked = vec![(MessageId(1), 1.0f64)];
2346 let mut timestamps = std::collections::HashMap::new();
2347 timestamps.insert(MessageId(1), ts);
2348 apply_temporal_decay(&mut ranked, ×tamps, 1);
2349 assert!(ranked[0].1 < 0.01, "score was {}", ranked[0].1);
2350 }
2351
2352 #[test]
2353 fn mmr_lambda_zero_max_diversity() {
2354 let ranked = vec![
2356 (MessageId(1), 1.0f64), (MessageId(2), 0.9f64), (MessageId(3), 0.85f64), ];
2360 let mut vectors = std::collections::HashMap::new();
2361 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2362 vectors.insert(MessageId(2), vec![0.0f32, 1.0]); vectors.insert(MessageId(3), vec![1.0f32, 0.0]); let result = apply_mmr(&ranked, &vectors, 0.0, 3);
2365 assert_eq!(result.len(), 3);
2366 assert_eq!(result[1].0, MessageId(2));
2368 }
2369
2370 #[test]
2371 fn mmr_lambda_one_pure_relevance() {
2372 let ranked = vec![
2374 (MessageId(1), 1.0f64),
2375 (MessageId(2), 0.8f64),
2376 (MessageId(3), 0.6f64),
2377 ];
2378 let mut vectors = std::collections::HashMap::new();
2379 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2380 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2381 vectors.insert(MessageId(3), vec![0.5f32, 0.5]);
2382 let result = apply_mmr(&ranked, &vectors, 1.0, 3);
2383 assert_eq!(result.len(), 3);
2384 assert_eq!(result[0].0, MessageId(1));
2385 assert_eq!(result[1].0, MessageId(2));
2386 assert_eq!(result[2].0, MessageId(3));
2387 }
2388
2389 #[test]
2390 fn mmr_limit_zero_returns_empty() {
2391 let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.8f64)];
2392 let mut vectors = std::collections::HashMap::new();
2393 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2394 vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2395 let result = apply_mmr(&ranked, &vectors, 0.7, 0);
2396 assert!(result.is_empty());
2397 }
2398
2399 #[test]
2400 fn mmr_duplicate_vectors_penalizes_second() {
2401 let ranked = vec![
2403 (MessageId(1), 1.0f64),
2404 (MessageId(2), 1.0f64), (MessageId(3), 0.9f64), ];
2407 let mut vectors = std::collections::HashMap::new();
2408 vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2409 vectors.insert(MessageId(2), vec![1.0f32, 0.0]); vectors.insert(MessageId(3), vec![0.0f32, 1.0]); let result = apply_mmr(&ranked, &vectors, 0.5, 3);
2412 assert_eq!(result.len(), 3);
2413 assert_eq!(result[0].0, MessageId(1));
2414 assert_eq!(result[1].0, MessageId(3));
2416 }
2417
2418 #[tokio::test]
2421 async fn recall_routed_keyword_route_returns_fts5_results() {
2422 use crate::{HeuristicRouter, MemoryRoute, MemoryRouter};
2423
2424 let memory = test_semantic_memory(false).await;
2425 let cid = memory.sqlite.create_conversation().await.unwrap();
2426
2427 memory
2428 .remember(cid, "user", "rust programming guide")
2429 .await
2430 .unwrap();
2431 memory
2432 .remember(cid, "assistant", "python tutorial")
2433 .await
2434 .unwrap();
2435
2436 let router = HeuristicRouter;
2438 assert_eq!(router.route("rust_guide"), MemoryRoute::Keyword);
2439
2440 let recalled = memory
2441 .recall_routed("rust_guide", 5, None, &router)
2442 .await
2443 .unwrap();
2444 assert!(recalled.len() <= 2);
2446 }
2447
2448 #[tokio::test]
2449 async fn recall_routed_semantic_route_without_qdrant_returns_empty_vectors() {
2450 use crate::{HeuristicRouter, MemoryRoute, MemoryRouter};
2451
2452 let memory = test_semantic_memory(false).await;
2453 let cid = memory.sqlite.create_conversation().await.unwrap();
2454
2455 memory
2456 .remember(cid, "user", "how does the agent loop work")
2457 .await
2458 .unwrap();
2459
2460 let router = HeuristicRouter;
2462 assert_eq!(
2463 router.route("how does the agent loop work"),
2464 MemoryRoute::Semantic
2465 );
2466
2467 let recalled = memory
2469 .recall_routed("how does the agent loop work", 5, None, &router)
2470 .await
2471 .unwrap();
2472 assert!(recalled.is_empty(), "no Qdrant → empty semantic recall");
2473 }
2474
2475 #[tokio::test]
2476 async fn recall_routed_hybrid_route_falls_back_to_fts5_on_no_qdrant() {
2477 use crate::{HeuristicRouter, MemoryRoute, MemoryRouter};
2478
2479 let memory = test_semantic_memory(false).await;
2480 let cid = memory.sqlite.create_conversation().await.unwrap();
2481
2482 memory
2483 .remember(cid, "user", "context window token budget")
2484 .await
2485 .unwrap();
2486
2487 let router = HeuristicRouter;
2489 assert_eq!(
2490 router.route("context window token budget"),
2491 MemoryRoute::Hybrid
2492 );
2493
2494 let recalled = memory
2496 .recall_routed("context window token budget", 5, None, &router)
2497 .await
2498 .unwrap();
2499 assert!(!recalled.is_empty(), "FTS5 should find the stored message");
2501 }
2502
2503 use proptest::prelude::*;
2506
    proptest! {
        // Property: count_tokens must never panic for any input string,
        // including empty, unicode, and control-character inputs. The
        // ".*" strategy generates arbitrary strings; the return value is
        // intentionally discarded — only panic-freedom is under test.
        #[test]
        fn count_tokens_never_panics(s in ".*") {
            let counter = crate::token_counter::TokenCounter::new();
            let _ = counter.count_tokens(&s);
        }
    }
2514}