use std::collections::HashSet;

use zeph_llm::provider::LlmProvider as _;

use crate::error::MemoryError;

use super::{CORRECTIONS_COLLECTION, SemanticMemory};

10impl SemanticMemory {
11 pub async fn store_correction_embedding(
19 &self,
20 correction_id: i64,
21 correction_text: &str,
22 ) -> Result<(), MemoryError> {
23 let Some(ref store) = self.qdrant else {
24 return Ok(());
25 };
26 if !self.provider.supports_embeddings() {
27 return Ok(());
28 }
29 let embedding = self
30 .provider
31 .embed(correction_text)
32 .await
33 .map_err(|e| MemoryError::Other(e.to_string()))?;
34 let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
35 store
36 .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
37 .await?;
38 let payload = serde_json::json!({ "correction_id": correction_id });
39 store
40 .store_to_collection(CORRECTIONS_COLLECTION, payload, embedding)
41 .await?;
42 Ok(())
43 }
44
45 pub async fn retrieve_similar_corrections(
54 &self,
55 query: &str,
56 limit: usize,
57 min_score: f32,
58 ) -> Result<Vec<crate::sqlite::corrections::UserCorrectionRow>, MemoryError> {
59 let Some(ref store) = self.qdrant else {
60 tracing::debug!("corrections: skipped, no vector store");
61 return Ok(vec![]);
62 };
63 if !self.provider.supports_embeddings() {
64 tracing::debug!("corrections: skipped, no embedding support");
65 return Ok(vec![]);
66 }
67 let embedding = self
68 .provider
69 .embed(query)
70 .await
71 .map_err(|e| MemoryError::Other(e.to_string()))?;
72 let vector_size = u64::try_from(embedding.len()).unwrap_or(896);
73 store
74 .ensure_named_collection(CORRECTIONS_COLLECTION, vector_size)
75 .await?;
76 let scored = store
77 .search_collection(CORRECTIONS_COLLECTION, &embedding, limit, None)
78 .await
79 .unwrap_or_default();
80
81 tracing::debug!(
82 candidates = scored.len(),
83 min_score = %min_score,
84 limit,
85 "corrections: search complete"
86 );
87
88 let mut results = Vec::new();
89 for point in scored {
90 if point.score < min_score {
91 continue;
92 }
93 if let Some(id_val) = point.payload.get("correction_id")
94 && let Some(id) = id_val.as_i64()
95 {
96 let rows = self.sqlite.load_corrections_for_id(id).await?;
97 results.extend(rows);
98 }
99 }
100
101 tracing::debug!(
102 retained = results.len(),
103 "corrections: after min_score filter"
104 );
105
106 Ok(results)
107 }
108}