zeph_memory/semantic/
corrections.rs1use std::time::Duration;
5
6use zeph_llm::provider::LlmProvider as _;
7
8use crate::error::MemoryError;
9
10use super::{CORRECTIONS_COLLECTION, SemanticMemory};
11
12impl SemanticMemory {
13 pub async fn store_correction_embedding(
21 &self,
22 correction_id: i64,
23 correction_text: &str,
24 ) -> Result<(), MemoryError> {
25 let Some(ref store) = self.qdrant else {
26 return Ok(());
27 };
28 if !self.effective_embed_provider().supports_embeddings() {
29 return Ok(());
30 }
31 let embedding = match tokio::time::timeout(
32 Duration::from_secs(5),
33 self.effective_embed_provider().embed(correction_text),
34 )
35 .await
36 {
37 Ok(Ok(v)) => v,
38 Ok(Err(e)) => return Err(MemoryError::Llm(e)),
39 Err(_) => {
40 tracing::warn!("corrections: embed timed out, skipping vector store write");
41 return Ok(());
42 }
43 };
44 let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
45 store
46 .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
47 .await?;
48 let payload = serde_json::json!({ "correction_id": correction_id });
49 store
50 .store_to_collection(CORRECTIONS_COLLECTION, payload, embedding)
51 .await?;
52 Ok(())
53 }
54
55 pub async fn retrieve_similar_corrections(
64 &self,
65 query: &str,
66 limit: usize,
67 min_score: f32,
68 ) -> Result<Vec<crate::store::corrections::UserCorrectionRow>, MemoryError> {
69 let Some(ref store) = self.qdrant else {
70 tracing::debug!("corrections: skipped, no vector store");
71 return Ok(vec![]);
72 };
73 if !self.effective_embed_provider().supports_embeddings() {
74 tracing::debug!("corrections: skipped, no embedding support");
75 return Ok(vec![]);
76 }
77 let embedding = self
78 .effective_embed_provider()
79 .embed(query)
80 .await
81 .map_err(MemoryError::Llm)?;
82 let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
83 store
84 .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
85 .await?;
86 let scored = store
87 .search_collection(CORRECTIONS_COLLECTION, &embedding, limit, None)
88 .await
89 .unwrap_or_default();
90
91 tracing::debug!(
92 candidates = scored.len(),
93 min_score = %min_score,
94 limit,
95 "corrections: search complete"
96 );
97
98 let mut results = Vec::new();
99 for point in scored {
100 if point.score < min_score {
101 continue;
102 }
103 if let Some(id_val) = point.payload.get("correction_id")
104 && let Some(id) = id_val.as_i64()
105 {
106 let rows = self.sqlite.load_corrections_for_id(id).await?;
107 results.extend(rows);
108 }
109 }
110
111 tracing::debug!(
112 retained = results.len(),
113 "corrections: after min_score filter"
114 );
115
116 Ok(results)
117 }
118}
119
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    use crate::embedding_store::EmbeddingStore;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::semantic::SemanticMemory;
    use crate::store::SqliteStore;
    use crate::token_counter::TokenCounter;

    /// Builds a `SemanticMemory` on an in-memory SQLite database and an in-memory vector
    /// store, with an embed provider that is a mock configured via
    /// `with_embed_delay(embed_delay_ms)`.
    async fn mem_with_slow_embed(embed_delay_ms: u64) -> SemanticMemory {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let vector_store = EmbeddingStore::with_store(
            Box::new(InMemoryVectorStore::new()),
            sqlite.pool().clone(),
        );
        let delayed_embedder = MockProvider::default().with_embed_delay(embed_delay_ms);

        SemanticMemory::from_parts(
            sqlite,
            Some(Arc::new(vector_store)),
            AnyProvider::Mock(MockProvider::default()),
            "test-model",
            0.7,
            0.3,
            Arc::new(TokenCounter::new()),
        )
        .with_embed_provider(AnyProvider::Mock(delayed_embedder))
    }

    /// An embed call that outlasts the timeout in `store_correction_embedding` must still
    /// yield `Ok(())`.
    #[tokio::test]
    async fn store_correction_embedding_embed_timeout_is_ok() {
        // Mock embed delay of 10_000 ms, longer than the 5 s timeout under test.
        let memory = mem_with_slow_embed(10_000).await;

        // Pause tokio's clock so the timeout is driven by `advance` below.
        tokio::time::pause();

        let store_fut = memory.store_correction_embedding(42, "I prefer detailed answers");
        let advance_clock = tokio::time::advance(Duration::from_secs(6));
        let (result, ()) = tokio::join!(store_fut, advance_clock);

        assert!(
            result.is_ok(),
            "embed timeout must return Ok(()) (fail-open, skip write), got {result:?}"
        );
    }
}
170}