Skip to main content

zeph_memory/
qdrant_ops.rs

//! Low-level Qdrant operations shared across crates.
2
3use std::collections::HashMap;
4
5use qdrant_client::Qdrant;
6use qdrant_client::qdrant::{
7    CreateCollectionBuilder, DeletePointsBuilder, Distance, Filter, PointId, PointStruct,
8    PointsIdsList, ScoredPoint, ScrollPointsBuilder, SearchPointsBuilder, UpsertPointsBuilder,
9    VectorParamsBuilder, value::Kind,
10};
11
12type QdrantResult<T> = Result<T, Box<qdrant_client::QdrantError>>;
13
/// Thin wrapper over the [`Qdrant`] client that bundles the common collection
/// and point operations (create, delete, upsert, search, scroll).
///
/// Cloning the wrapper clones the inner client.
#[derive(Clone)]
pub struct QdrantOps {
    /// Underlying Qdrant client that every operation is issued through.
    client: Qdrant,
}
19
20impl std::fmt::Debug for QdrantOps {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        f.debug_struct("QdrantOps").finish_non_exhaustive()
23    }
24}
25
26impl QdrantOps {
27    /// Create a new `QdrantOps` connected to the given URL.
28    ///
29    /// # Errors
30    ///
31    /// Returns an error if the Qdrant client cannot be created.
32    pub fn new(url: &str) -> QdrantResult<Self> {
33        let client = Qdrant::from_url(url).build().map_err(Box::new)?;
34        Ok(Self { client })
35    }
36
    /// Borrow the wrapped [`Qdrant`] client, for operations this wrapper does
    /// not expose itself.
    #[must_use]
    pub fn client(&self) -> &Qdrant {
        &self.client
    }
42
43    /// Ensure a collection exists with cosine distance vectors.
44    ///
45    /// Idempotent: no-op if the collection already exists.
46    ///
47    /// # Errors
48    ///
49    /// Returns an error if Qdrant cannot be reached or collection creation fails.
50    pub async fn ensure_collection(&self, collection: &str, vector_size: u64) -> QdrantResult<()> {
51        if self
52            .client
53            .collection_exists(collection)
54            .await
55            .map_err(Box::new)?
56        {
57            return Ok(());
58        }
59        self.client
60            .create_collection(
61                CreateCollectionBuilder::new(collection)
62                    .vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine)),
63            )
64            .await
65            .map_err(Box::new)?;
66        Ok(())
67    }
68
69    /// Check whether a collection exists.
70    ///
71    /// # Errors
72    ///
73    /// Returns an error if Qdrant cannot be reached.
74    pub async fn collection_exists(&self, collection: &str) -> QdrantResult<bool> {
75        self.client
76            .collection_exists(collection)
77            .await
78            .map_err(Box::new)
79    }
80
81    /// Delete a collection.
82    ///
83    /// # Errors
84    ///
85    /// Returns an error if the collection cannot be deleted.
86    pub async fn delete_collection(&self, collection: &str) -> QdrantResult<()> {
87        self.client
88            .delete_collection(collection)
89            .await
90            .map_err(Box::new)?;
91        Ok(())
92    }
93
94    /// Upsert points into a collection.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if the upsert fails.
99    pub async fn upsert(&self, collection: &str, points: Vec<PointStruct>) -> QdrantResult<()> {
100        self.client
101            .upsert_points(UpsertPointsBuilder::new(collection, points).wait(true))
102            .await
103            .map_err(Box::new)?;
104        Ok(())
105    }
106
107    /// Search for similar vectors, returning scored points with payloads.
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if the search fails.
112    pub async fn search(
113        &self,
114        collection: &str,
115        vector: Vec<f32>,
116        limit: u64,
117        filter: Option<Filter>,
118    ) -> QdrantResult<Vec<ScoredPoint>> {
119        let mut builder = SearchPointsBuilder::new(collection, vector, limit).with_payload(true);
120        if let Some(f) = filter {
121            builder = builder.filter(f);
122        }
123        let results = self.client.search_points(builder).await.map_err(Box::new)?;
124        Ok(results.result)
125    }
126
127    /// Delete points by their IDs.
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if the deletion fails.
132    pub async fn delete_by_ids(&self, collection: &str, ids: Vec<PointId>) -> QdrantResult<()> {
133        if ids.is_empty() {
134            return Ok(());
135        }
136        self.client
137            .delete_points(
138                DeletePointsBuilder::new(collection)
139                    .points(PointsIdsList { ids })
140                    .wait(true),
141            )
142            .await
143            .map_err(Box::new)?;
144        Ok(())
145    }
146
147    /// Scroll all points in a collection, extracting string payload fields.
148    ///
149    /// Returns a map of `key_field` value -> { `field_name` -> `field_value` }.
150    ///
151    /// # Errors
152    ///
153    /// Returns an error if the scroll operation fails.
154    pub async fn scroll_all(
155        &self,
156        collection: &str,
157        key_field: &str,
158    ) -> QdrantResult<HashMap<String, HashMap<String, String>>> {
159        let mut result = HashMap::new();
160        let mut offset: Option<PointId> = None;
161
162        loop {
163            let mut builder = ScrollPointsBuilder::new(collection)
164                .with_payload(true)
165                .with_vectors(false)
166                .limit(100);
167
168            if let Some(ref off) = offset {
169                builder = builder.offset(off.clone());
170            }
171
172            let response = self.client.scroll(builder).await.map_err(Box::new)?;
173
174            for point in &response.result {
175                let Some(key_val) = point.payload.get(key_field) else {
176                    continue;
177                };
178                let Some(Kind::StringValue(key)) = &key_val.kind else {
179                    continue;
180                };
181
182                let mut fields = HashMap::new();
183                for (k, val) in &point.payload {
184                    if let Some(Kind::StringValue(s)) = &val.kind {
185                        fields.insert(k.clone(), s.clone());
186                    }
187                }
188                result.insert(key.clone(), fields);
189            }
190
191            match response.next_page_offset {
192                Some(next) => offset = Some(next),
193                None => break,
194            }
195        }
196
197        Ok(result)
198    }
199
    /// Convert a JSON value into a Qdrant payload map (field name -> Qdrant
    /// value).
    ///
    /// The conversion is delegated entirely to `serde_json::from_value`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `value` cannot be deserialized into
    /// the payload map type.
    pub fn json_to_payload(
        value: serde_json::Value,
    ) -> Result<HashMap<String, qdrant_client::qdrant::Value>, serde_json::Error> {
        serde_json::from_value(value)
    }
210}
211
212impl crate::vector_store::VectorStore for QdrantOps {
213    fn ensure_collection(
214        &self,
215        collection: &str,
216        vector_size: u64,
217    ) -> std::pin::Pin<
218        Box<dyn std::future::Future<Output = Result<(), crate::VectorStoreError>> + Send + '_>,
219    > {
220        let collection = collection.to_owned();
221        Box::pin(async move {
222            self.ensure_collection(&collection, vector_size)
223                .await
224                .map_err(|e| crate::VectorStoreError::Collection(e.to_string()))
225        })
226    }
227
228    fn collection_exists(
229        &self,
230        collection: &str,
231    ) -> std::pin::Pin<
232        Box<dyn std::future::Future<Output = Result<bool, crate::VectorStoreError>> + Send + '_>,
233    > {
234        let collection = collection.to_owned();
235        Box::pin(async move {
236            self.collection_exists(&collection)
237                .await
238                .map_err(|e| crate::VectorStoreError::Collection(e.to_string()))
239        })
240    }
241
242    fn delete_collection(
243        &self,
244        collection: &str,
245    ) -> std::pin::Pin<
246        Box<dyn std::future::Future<Output = Result<(), crate::VectorStoreError>> + Send + '_>,
247    > {
248        let collection = collection.to_owned();
249        Box::pin(async move {
250            self.delete_collection(&collection)
251                .await
252                .map_err(|e| crate::VectorStoreError::Collection(e.to_string()))
253        })
254    }
255
256    fn upsert(
257        &self,
258        collection: &str,
259        points: Vec<crate::VectorPoint>,
260    ) -> std::pin::Pin<
261        Box<dyn std::future::Future<Output = Result<(), crate::VectorStoreError>> + Send + '_>,
262    > {
263        let collection = collection.to_owned();
264        Box::pin(async move {
265            let qdrant_points: Vec<PointStruct> = points
266                .into_iter()
267                .map(|p| {
268                    let payload: HashMap<String, qdrant_client::qdrant::Value> =
269                        serde_json::from_value(serde_json::Value::Object(
270                            p.payload.into_iter().collect(),
271                        ))
272                        .unwrap_or_default();
273                    PointStruct::new(p.id, p.vector, payload)
274                })
275                .collect();
276            self.upsert(&collection, qdrant_points)
277                .await
278                .map_err(|e| crate::VectorStoreError::Upsert(e.to_string()))
279        })
280    }
281
282    fn search(
283        &self,
284        collection: &str,
285        vector: Vec<f32>,
286        limit: u64,
287        filter: Option<crate::VectorFilter>,
288    ) -> std::pin::Pin<
289        Box<
290            dyn std::future::Future<
291                    Output = Result<Vec<crate::ScoredVectorPoint>, crate::VectorStoreError>,
292                > + Send
293                + '_,
294        >,
295    > {
296        let collection = collection.to_owned();
297        Box::pin(async move {
298            let qdrant_filter = filter.map(vector_filter_to_qdrant);
299            let results = self
300                .search(&collection, vector, limit, qdrant_filter)
301                .await
302                .map_err(|e| crate::VectorStoreError::Search(e.to_string()))?;
303            Ok(results.into_iter().map(scored_point_to_vector).collect())
304        })
305    }
306
307    fn delete_by_ids(
308        &self,
309        collection: &str,
310        ids: Vec<String>,
311    ) -> std::pin::Pin<
312        Box<dyn std::future::Future<Output = Result<(), crate::VectorStoreError>> + Send + '_>,
313    > {
314        let collection = collection.to_owned();
315        Box::pin(async move {
316            let point_ids: Vec<PointId> = ids.into_iter().map(PointId::from).collect();
317            self.delete_by_ids(&collection, point_ids)
318                .await
319                .map_err(|e| crate::VectorStoreError::Delete(e.to_string()))
320        })
321    }
322
323    fn scroll_all(
324        &self,
325        collection: &str,
326        key_field: &str,
327    ) -> std::pin::Pin<
328        Box<
329            dyn std::future::Future<
330                    Output = Result<
331                        HashMap<String, HashMap<String, String>>,
332                        crate::VectorStoreError,
333                    >,
334                > + Send
335                + '_,
336        >,
337    > {
338        let collection = collection.to_owned();
339        let key_field = key_field.to_owned();
340        Box::pin(async move {
341            self.scroll_all(&collection, &key_field)
342                .await
343                .map_err(|e| crate::VectorStoreError::Scroll(e.to_string()))
344        })
345    }
346
347    fn health_check(
348        &self,
349    ) -> std::pin::Pin<
350        Box<dyn std::future::Future<Output = Result<bool, crate::VectorStoreError>> + Send + '_>,
351    > {
352        Box::pin(async move {
353            self.client
354                .health_check()
355                .await
356                .map(|_| true)
357                .map_err(|e| crate::VectorStoreError::Collection(e.to_string()))
358        })
359    }
360}
361
362fn vector_filter_to_qdrant(filter: crate::VectorFilter) -> Filter {
363    let must: Vec<_> = filter
364        .must
365        .into_iter()
366        .map(field_condition_to_qdrant)
367        .collect();
368    let must_not: Vec<_> = filter
369        .must_not
370        .into_iter()
371        .map(field_condition_to_qdrant)
372        .collect();
373
374    let mut f = Filter::default();
375    if !must.is_empty() {
376        f.must = must;
377    }
378    if !must_not.is_empty() {
379        f.must_not = must_not;
380    }
381    f
382}
383
384fn field_condition_to_qdrant(cond: crate::FieldCondition) -> qdrant_client::qdrant::Condition {
385    match cond.value {
386        crate::FieldValue::Integer(v) => qdrant_client::qdrant::Condition::matches(cond.field, v),
387        crate::FieldValue::Text(v) => qdrant_client::qdrant::Condition::matches(cond.field, v),
388    }
389}
390
391fn scored_point_to_vector(point: ScoredPoint) -> crate::ScoredVectorPoint {
392    let payload: HashMap<String, serde_json::Value> = point
393        .payload
394        .into_iter()
395        .filter_map(|(k, v)| {
396            let json_val = match v.kind? {
397                Kind::StringValue(s) => serde_json::Value::String(s),
398                Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
399                Kind::DoubleValue(d) => {
400                    serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
401                }
402                Kind::BoolValue(b) => serde_json::Value::Bool(b),
403                _ => return None,
404            };
405            Some((k, json_val))
406        })
407        .collect();
408
409    let id = match point.id.and_then(|pid| pid.point_id_options) {
410        Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
411        Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
412        None => String::new(),
413    };
414
415    crate::ScoredVectorPoint {
416        id,
417        score: point.score,
418        payload,
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// URL used wherever a syntactically valid endpoint is needed.
    const LOCAL_URL: &str = "http://localhost:6334";

    /// A well-formed URL yields a wrapper.
    #[test]
    fn new_valid_url() {
        assert!(QdrantOps::new(LOCAL_URL).is_ok());
    }

    /// A malformed URL is rejected at construction time.
    #[test]
    fn new_invalid_url() {
        assert!(QdrantOps::new("not a valid url").is_err());
    }

    /// The `Debug` output names the type.
    #[test]
    fn debug_format() {
        let ops = QdrantOps::new(LOCAL_URL).unwrap();
        let rendered = format!("{ops:?}");
        assert!(rendered.contains("QdrantOps"));
    }

    /// A JSON object with mixed value types converts without error.
    #[test]
    fn json_to_payload_valid() {
        let payload = QdrantOps::json_to_payload(serde_json::json!({"key": "value", "num": 42}));
        assert!(payload.is_ok());
    }

    /// An empty JSON object converts to an empty payload map.
    #[test]
    fn json_to_payload_empty() {
        let payload = QdrantOps::json_to_payload(serde_json::json!({}));
        assert!(matches!(payload, Ok(ref map) if map.is_empty()));
    }
}
458}