zeph_memory/semantic/
corrections.rs1use zeph_llm::provider::LlmProvider as _;
5
6use crate::error::MemoryError;
7
8use super::{CORRECTIONS_COLLECTION, SemanticMemory};
9
10impl SemanticMemory {
11 pub async fn store_correction_embedding(
19 &self,
20 correction_id: i64,
21 correction_text: &str,
22 ) -> Result<(), MemoryError> {
23 let Some(ref store) = self.qdrant else {
24 return Ok(());
25 };
26 if !self.effective_embed_provider().supports_embeddings() {
27 return Ok(());
28 }
29 let embedding = match tokio::time::timeout(
30 self.embed_timeout,
31 self.effective_embed_provider().embed(correction_text),
32 )
33 .await
34 {
35 Ok(Ok(v)) => v,
36 Ok(Err(e)) => return Err(MemoryError::Llm(e)),
37 Err(_) => {
38 tracing::warn!("corrections: embed timed out, skipping vector store write");
39 return Ok(());
40 }
41 };
42 let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
43 store
44 .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
45 .await?;
46 let payload = serde_json::json!({ "correction_id": correction_id });
47 store
48 .store_to_collection(CORRECTIONS_COLLECTION, payload, embedding)
49 .await?;
50 Ok(())
51 }
52
53 pub async fn retrieve_similar_corrections(
62 &self,
63 query: &str,
64 limit: usize,
65 min_score: f32,
66 ) -> Result<Vec<crate::store::corrections::UserCorrectionRow>, MemoryError> {
67 let Some(ref store) = self.qdrant else {
68 tracing::debug!("corrections: skipped, no vector store");
69 return Ok(vec![]);
70 };
71 if !self.effective_embed_provider().supports_embeddings() {
72 tracing::debug!("corrections: skipped, no embedding support");
73 return Ok(vec![]);
74 }
75 let embedding = match tokio::time::timeout(
76 self.embed_timeout,
77 self.effective_embed_provider().embed(query),
78 )
79 .await
80 {
81 Ok(Ok(v)) => v,
82 Ok(Err(e)) => return Err(MemoryError::Llm(e)),
83 Err(_) => {
84 tracing::warn!("search_corrections: embed() timed out, returning empty");
85 return Ok(vec![]);
86 }
87 };
88 let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
89 store
90 .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
91 .await?;
92 let scored = store
93 .search_collection(CORRECTIONS_COLLECTION, &embedding, limit, None)
94 .await
95 .unwrap_or_default();
96
97 tracing::debug!(
98 candidates = scored.len(),
99 min_score = %min_score,
100 limit,
101 "corrections: search complete"
102 );
103
104 let mut results = Vec::new();
105 for point in scored {
106 if point.score < min_score {
107 continue;
108 }
109 if let Some(id_val) = point.payload.get("correction_id")
110 && let Some(id) = id_val.as_i64()
111 {
112 let rows = self.sqlite.load_corrections_for_id(id).await?;
113 results.extend(rows);
114 }
115 }
116
117 tracing::debug!(
118 retained = results.len(),
119 "corrections: after min_score filter"
120 );
121
122 Ok(results)
123 }
124}
125
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    use crate::embedding_store::EmbeddingStore;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::semantic::SemanticMemory;
    use crate::store::SqliteStore;
    use crate::token_counter::TokenCounter;

    /// How far the paused clock is pushed forward in each timeout test.
    /// NOTE(review): assumed to exceed the default `embed_timeout` — confirm.
    const CLOCK_ADVANCE: Duration = Duration::from_secs(6);

    /// Artificial embed delay (ms), well beyond `CLOCK_ADVANCE`.
    const SLOW_EMBED_MS: u64 = 10_000;

    /// Builds a `SemanticMemory` over in-memory SQLite and an in-memory vector store.
    /// Its dedicated embedding provider is a mock configured via
    /// `MockProvider::with_embed_delay(embed_delay_ms)`.
    async fn mem_with_slow_embed(embed_delay_ms: u64) -> SemanticMemory {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let vector_store = EmbeddingStore::with_store(
            Box::new(InMemoryVectorStore::new()),
            sqlite.pool().clone(),
        );
        let primary = AnyProvider::Mock(MockProvider::default());
        let delayed_embedder =
            AnyProvider::Mock(MockProvider::default().with_embed_delay(embed_delay_ms));

        let memory = SemanticMemory::from_parts(
            sqlite,
            Some(Arc::new(vector_store)),
            primary,
            "test-model",
            0.7,
            0.3,
            Arc::new(TokenCounter::new()),
        );
        memory.with_embedding_provider(delayed_embedder)
    }

    #[tokio::test]
    async fn store_correction_embedding_embed_timeout_is_ok() {
        let mem = mem_with_slow_embed(SLOW_EMBED_MS).await;
        // The clock is frozen only after setup, so construction above runs on real time.
        tokio::time::pause();

        let advance_clock = async { tokio::time::advance(CLOCK_ADVANCE).await };
        let (result, ()) = tokio::join!(
            mem.store_correction_embedding(42, "I prefer detailed answers"),
            advance_clock
        );

        assert!(
            result.is_ok(),
            "embed timeout must return Ok(()) (fail-open, skip write), got {result:?}"
        );
    }

    #[tokio::test]
    async fn retrieve_similar_corrections_embed_timeout_returns_empty() {
        let mem = mem_with_slow_embed(SLOW_EMBED_MS).await;
        // The clock is frozen only after setup, so construction above runs on real time.
        tokio::time::pause();

        let advance_clock = async { tokio::time::advance(CLOCK_ADVANCE).await };
        let (result, ()) = tokio::join!(
            mem.retrieve_similar_corrections("prefer concise answers", 5, 0.7),
            advance_clock
        );

        let rows = result
            .unwrap_or_else(|e| panic!("embed timeout must not propagate error, got {e:?}"));
        assert!(
            rows.is_empty(),
            "embed timeout must return empty vec (fail-open), got {rows:?}"
        );
    }
}