Skip to main content

zeph_memory/semantic/
corrections.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::time::Duration;
5
6use zeph_llm::provider::LlmProvider as _;
7
8use crate::error::MemoryError;
9
10use super::{CORRECTIONS_COLLECTION, SemanticMemory};
11
12impl SemanticMemory {
13    /// Store an embedding for a user correction in the vector store.
14    ///
15    /// Silently skips if no vector store is configured or embeddings are unsupported.
16    ///
17    /// # Errors
18    ///
19    /// Returns an error if embedding generation or vector store write fails.
20    pub async fn store_correction_embedding(
21        &self,
22        correction_id: i64,
23        correction_text: &str,
24    ) -> Result<(), MemoryError> {
25        let Some(ref store) = self.qdrant else {
26            return Ok(());
27        };
28        if !self.effective_embed_provider().supports_embeddings() {
29            return Ok(());
30        }
31        let embedding = match tokio::time::timeout(
32            Duration::from_secs(5),
33            self.effective_embed_provider().embed(correction_text),
34        )
35        .await
36        {
37            Ok(Ok(v)) => v,
38            Ok(Err(e)) => return Err(MemoryError::Llm(e)),
39            Err(_) => {
40                tracing::warn!("corrections: embed timed out, skipping vector store write");
41                return Ok(());
42            }
43        };
44        let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
45        store
46            .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
47            .await?;
48        let payload = serde_json::json!({ "correction_id": correction_id });
49        store
50            .store_to_collection(CORRECTIONS_COLLECTION, payload, embedding)
51            .await?;
52        Ok(())
53    }
54
55    /// Retrieve corrections semantically similar to `query`.
56    ///
57    /// Returns up to `limit` corrections scoring above `min_score`.
58    /// Returns an empty vec if no vector store is configured.
59    ///
60    /// # Errors
61    ///
62    /// Returns an error if embedding generation or vector search fails.
63    pub async fn retrieve_similar_corrections(
64        &self,
65        query: &str,
66        limit: usize,
67        min_score: f32,
68    ) -> Result<Vec<crate::store::corrections::UserCorrectionRow>, MemoryError> {
69        let Some(ref store) = self.qdrant else {
70            tracing::debug!("corrections: skipped, no vector store");
71            return Ok(vec![]);
72        };
73        if !self.effective_embed_provider().supports_embeddings() {
74            tracing::debug!("corrections: skipped, no embedding support");
75            return Ok(vec![]);
76        }
77        let embedding = self
78            .effective_embed_provider()
79            .embed(query)
80            .await
81            .map_err(MemoryError::Llm)?;
82        let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
83        store
84            .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
85            .await?;
86        let scored = store
87            .search_collection(CORRECTIONS_COLLECTION, &embedding, limit, None)
88            .await
89            .unwrap_or_default();
90
91        tracing::debug!(
92            candidates = scored.len(),
93            min_score = %min_score,
94            limit,
95            "corrections: search complete"
96        );
97
98        let mut results = Vec::new();
99        for point in scored {
100            if point.score < min_score {
101                continue;
102            }
103            if let Some(id_val) = point.payload.get("correction_id")
104                && let Some(id) = id_val.as_i64()
105            {
106                let rows = self.sqlite.load_corrections_for_id(id).await?;
107                results.extend(rows);
108            }
109        }
110
111        tracing::debug!(
112            retained = results.len(),
113            "corrections: after min_score filter"
114        );
115
116        Ok(results)
117    }
118}
119
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    use crate::embedding_store::EmbeddingStore;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::semantic::SemanticMemory;
    use crate::store::SqliteStore;
    use crate::token_counter::TokenCounter;

    /// Builds a `SemanticMemory` with the following setup:
    /// - storage: in-memory SQLite plus an in-memory vector store;
    /// - base provider: a default mock;
    /// - embedding provider: a mock configured via
    ///   `MockProvider::with_embed_delay(embed_delay_ms)`.
    async fn mem_with_slow_embed(embed_delay_ms: u64) -> SemanticMemory {
        let default_provider = AnyProvider::Mock(MockProvider::default());
        let delayed_embedder =
            AnyProvider::Mock(MockProvider::default().with_embed_delay(embed_delay_ms));

        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let vectors = EmbeddingStore::with_store(
            Box::new(InMemoryVectorStore::new()),
            sqlite.pool().clone(),
        );

        let memory = SemanticMemory::from_parts(
            sqlite,
            Some(Arc::new(vectors)),
            default_provider,
            "test-model",
            0.7,
            0.3,
            Arc::new(TokenCounter::new()),
        );
        memory.with_embed_provider(delayed_embedder)
    }

    /// Fail-open contract: when `embed()` exceeds the timeout inside
    /// `store_correction_embedding`, the call yields `Ok(())` rather than an error.
    #[tokio::test]
    async fn store_correction_embedding_embed_timeout_is_ok() {
        // The SQLite pool relies on tokio timers, so construct everything while
        // the clock is still running and freeze time only afterwards.
        let memory = mem_with_slow_embed(10_000).await;
        tokio::time::pause();

        // Jump the paused clock past the embed timeout while the store call is pending.
        let skip_past_timeout = async {
            tokio::time::advance(Duration::from_secs(6)).await;
        };
        let (result, ()) = tokio::join!(
            memory.store_correction_embedding(42, "I prefer detailed answers"),
            skip_past_timeout
        );

        assert!(
            result.is_ok(),
            "embed timeout must return Ok(()) (fail-open, skip write), got {result:?}"
        );
    }
}