Skip to main content

zeph_memory/semantic/
corrections.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::provider::LlmProvider as _;
5
6use crate::error::MemoryError;
7
8use super::{CORRECTIONS_COLLECTION, SemanticMemory};
9
10impl SemanticMemory {
11    /// Store an embedding for a user correction in the vector store.
12    ///
13    /// Silently skips if no vector store is configured or embeddings are unsupported.
14    ///
15    /// # Errors
16    ///
17    /// Returns an error if embedding generation or vector store write fails.
18    pub async fn store_correction_embedding(
19        &self,
20        correction_id: i64,
21        correction_text: &str,
22    ) -> Result<(), MemoryError> {
23        let Some(ref store) = self.qdrant else {
24            return Ok(());
25        };
26        if !self.effective_embed_provider().supports_embeddings() {
27            return Ok(());
28        }
29        let embedding = match tokio::time::timeout(
30            self.embed_timeout,
31            self.effective_embed_provider().embed(correction_text),
32        )
33        .await
34        {
35            Ok(Ok(v)) => v,
36            Ok(Err(e)) => return Err(MemoryError::Llm(e)),
37            Err(_) => {
38                tracing::warn!("corrections: embed timed out, skipping vector store write");
39                return Ok(());
40            }
41        };
42        let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
43        store
44            .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
45            .await?;
46        let payload = serde_json::json!({ "correction_id": correction_id });
47        store
48            .store_to_collection(CORRECTIONS_COLLECTION, payload, embedding)
49            .await?;
50        Ok(())
51    }
52
53    /// Retrieve corrections semantically similar to `query`.
54    ///
55    /// Returns up to `limit` corrections scoring above `min_score`.
56    /// Returns an empty vec if no vector store is configured.
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if embedding generation or vector search fails.
61    pub async fn retrieve_similar_corrections(
62        &self,
63        query: &str,
64        limit: usize,
65        min_score: f32,
66    ) -> Result<Vec<crate::store::corrections::UserCorrectionRow>, MemoryError> {
67        let Some(ref store) = self.qdrant else {
68            tracing::debug!("corrections: skipped, no vector store");
69            return Ok(vec![]);
70        };
71        if !self.effective_embed_provider().supports_embeddings() {
72            tracing::debug!("corrections: skipped, no embedding support");
73            return Ok(vec![]);
74        }
75        let embedding = match tokio::time::timeout(
76            self.embed_timeout,
77            self.effective_embed_provider().embed(query),
78        )
79        .await
80        {
81            Ok(Ok(v)) => v,
82            Ok(Err(e)) => return Err(MemoryError::Llm(e)),
83            Err(_) => {
84                tracing::warn!("search_corrections: embed() timed out, returning empty");
85                return Ok(vec![]);
86            }
87        };
88        let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
89        store
90            .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
91            .await?;
92        let scored = store
93            .search_collection(CORRECTIONS_COLLECTION, &embedding, limit, None)
94            .await
95            .unwrap_or_default();
96
97        tracing::debug!(
98            candidates = scored.len(),
99            min_score = %min_score,
100            limit,
101            "corrections: search complete"
102        );
103
104        let mut results = Vec::new();
105        for point in scored {
106            if point.score < min_score {
107                continue;
108            }
109            if let Some(id_val) = point.payload.get("correction_id")
110                && let Some(id) = id_val.as_i64()
111            {
112                let rows = self.sqlite.load_corrections_for_id(id).await?;
113                results.extend(rows);
114            }
115        }
116
117        tracing::debug!(
118            retained = results.len(),
119            "corrections: after min_score filter"
120        );
121
122        Ok(results)
123    }
124}
125
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    use crate::embedding_store::EmbeddingStore;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::semantic::SemanticMemory;
    use crate::store::SqliteStore;
    use crate::token_counter::TokenCounter;

    /// Virtual time advanced while the slow embed is pending; these tests assume it exceeds
    /// the configured `embed_timeout`.
    const ADVANCE_PAST_TIMEOUT: Duration = Duration::from_secs(6);

    /// Builds a `SemanticMemory` backed by in-memory SQLite and an in-memory vector store,
    /// whose embedding provider sleeps `embed_delay_ms` before answering.
    async fn mem_with_slow_embed(embed_delay_ms: u64) -> SemanticMemory {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let vector_store = EmbeddingStore::with_store(
            Box::new(InMemoryVectorStore::new()),
            sqlite.pool().clone(),
        );
        let chat_provider = AnyProvider::Mock(MockProvider::default());
        let delayed_embedder =
            AnyProvider::Mock(MockProvider::default().with_embed_delay(embed_delay_ms));

        SemanticMemory::from_parts(
            sqlite,
            Some(Arc::new(vector_store)),
            chat_provider,
            "test-model",
            0.7,
            0.3,
            Arc::new(TokenCounter::new()),
        )
        .with_embedding_provider(delayed_embedder)
    }

    /// Timed-out `embed()` in `store_correction_embedding` yields `Ok(())` and skips the write.
    #[tokio::test]
    async fn store_correction_embedding_embed_timeout_is_ok() {
        // The SQLite pool relies on tokio timers, so construct it while the clock still runs.
        let mem = mem_with_slow_embed(10_000).await;
        tokio::time::pause();

        let (result, ()) = tokio::join!(
            mem.store_correction_embedding(42, "I prefer detailed answers"),
            tokio::time::advance(ADVANCE_PAST_TIMEOUT),
        );

        assert!(
            result.is_ok(),
            "embed timeout must return Ok(()) (fail-open, skip write), got {result:?}"
        );
    }

    /// Timed-out `embed()` in `retrieve_similar_corrections` yields `Ok(vec![])`.
    #[tokio::test]
    async fn retrieve_similar_corrections_embed_timeout_returns_empty() {
        let mem = mem_with_slow_embed(10_000).await;
        tokio::time::pause();

        let (result, ()) = tokio::join!(
            mem.retrieve_similar_corrections("prefer concise answers", 5, 0.7),
            tokio::time::advance(ADVANCE_PAST_TIMEOUT),
        );

        let rows = match result {
            Ok(rows) => rows,
            Err(e) => panic!("embed timeout must not propagate error, got {e:?}"),
        };
        assert!(
            rows.is_empty(),
            "embed timeout must return empty vec (fail-open), got {rows:?}"
        );
    }
}