1use std::collections::HashMap;
4
5use qdrant_client::Qdrant;
6use qdrant_client::qdrant::{
7 CreateCollectionBuilder, DeletePointsBuilder, Distance, Filter, PointId, PointStruct,
8 PointsIdsList, ScoredPoint, ScrollPointsBuilder, SearchPointsBuilder, UpsertPointsBuilder,
9 VectorParamsBuilder, value::Kind,
10};
11
12type QdrantResult<T> = Result<T, Box<qdrant_client::QdrantError>>;
13
14#[derive(Clone)]
16pub struct QdrantOps {
17 client: Qdrant,
18}
19
20impl std::fmt::Debug for QdrantOps {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 f.debug_struct("QdrantOps").finish_non_exhaustive()
23 }
24}
25
26impl QdrantOps {
27 pub fn new(url: &str) -> QdrantResult<Self> {
33 let client = Qdrant::from_url(url).build().map_err(Box::new)?;
34 Ok(Self { client })
35 }
36
37 #[must_use]
39 pub fn client(&self) -> &Qdrant {
40 &self.client
41 }
42
43 pub async fn ensure_collection(&self, collection: &str, vector_size: u64) -> QdrantResult<()> {
51 if self
52 .client
53 .collection_exists(collection)
54 .await
55 .map_err(Box::new)?
56 {
57 return Ok(());
58 }
59 self.client
60 .create_collection(
61 CreateCollectionBuilder::new(collection)
62 .vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine)),
63 )
64 .await
65 .map_err(Box::new)?;
66 Ok(())
67 }
68
69 pub async fn collection_exists(&self, collection: &str) -> QdrantResult<bool> {
75 self.client
76 .collection_exists(collection)
77 .await
78 .map_err(Box::new)
79 }
80
81 pub async fn delete_collection(&self, collection: &str) -> QdrantResult<()> {
87 self.client
88 .delete_collection(collection)
89 .await
90 .map_err(Box::new)?;
91 Ok(())
92 }
93
94 pub async fn upsert(&self, collection: &str, points: Vec<PointStruct>) -> QdrantResult<()> {
100 self.client
101 .upsert_points(UpsertPointsBuilder::new(collection, points).wait(true))
102 .await
103 .map_err(Box::new)?;
104 Ok(())
105 }
106
107 pub async fn search(
113 &self,
114 collection: &str,
115 vector: Vec<f32>,
116 limit: u64,
117 filter: Option<Filter>,
118 ) -> QdrantResult<Vec<ScoredPoint>> {
119 let mut builder = SearchPointsBuilder::new(collection, vector, limit).with_payload(true);
120 if let Some(f) = filter {
121 builder = builder.filter(f);
122 }
123 let results = self.client.search_points(builder).await.map_err(Box::new)?;
124 Ok(results.result)
125 }
126
127 pub async fn delete_by_ids(&self, collection: &str, ids: Vec<PointId>) -> QdrantResult<()> {
133 if ids.is_empty() {
134 return Ok(());
135 }
136 self.client
137 .delete_points(
138 DeletePointsBuilder::new(collection)
139 .points(PointsIdsList { ids })
140 .wait(true),
141 )
142 .await
143 .map_err(Box::new)?;
144 Ok(())
145 }
146
147 pub async fn scroll_all(
155 &self,
156 collection: &str,
157 key_field: &str,
158 ) -> QdrantResult<HashMap<String, HashMap<String, String>>> {
159 let mut result = HashMap::new();
160 let mut offset: Option<PointId> = None;
161
162 loop {
163 let mut builder = ScrollPointsBuilder::new(collection)
164 .with_payload(true)
165 .with_vectors(false)
166 .limit(100);
167
168 if let Some(ref off) = offset {
169 builder = builder.offset(off.clone());
170 }
171
172 let response = self.client.scroll(builder).await.map_err(Box::new)?;
173
174 for point in &response.result {
175 let Some(key_val) = point.payload.get(key_field) else {
176 continue;
177 };
178 let Some(Kind::StringValue(key)) = &key_val.kind else {
179 continue;
180 };
181
182 let mut fields = HashMap::new();
183 for (k, val) in &point.payload {
184 if let Some(Kind::StringValue(s)) = &val.kind {
185 fields.insert(k.clone(), s.clone());
186 }
187 }
188 result.insert(key.clone(), fields);
189 }
190
191 match response.next_page_offset {
192 Some(next) => offset = Some(next),
193 None => break,
194 }
195 }
196
197 Ok(result)
198 }
199
200 pub fn json_to_payload(
206 value: serde_json::Value,
207 ) -> Result<HashMap<String, qdrant_client::qdrant::Value>, serde_json::Error> {
208 serde_json::from_value(value)
209 }
210}
211
212impl crate::vector_store::VectorStore for QdrantOps {
213 fn ensure_collection(
214 &self,
215 collection: &str,
216 vector_size: u64,
217 ) -> std::pin::Pin<
218 Box<dyn std::future::Future<Output = Result<(), crate::VectorStoreError>> + Send + '_>,
219 > {
220 let collection = collection.to_owned();
221 Box::pin(async move {
222 self.ensure_collection(&collection, vector_size)
223 .await
224 .map_err(|e| crate::VectorStoreError::Collection(e.to_string()))
225 })
226 }
227
228 fn collection_exists(
229 &self,
230 collection: &str,
231 ) -> std::pin::Pin<
232 Box<dyn std::future::Future<Output = Result<bool, crate::VectorStoreError>> + Send + '_>,
233 > {
234 let collection = collection.to_owned();
235 Box::pin(async move {
236 self.collection_exists(&collection)
237 .await
238 .map_err(|e| crate::VectorStoreError::Collection(e.to_string()))
239 })
240 }
241
242 fn delete_collection(
243 &self,
244 collection: &str,
245 ) -> std::pin::Pin<
246 Box<dyn std::future::Future<Output = Result<(), crate::VectorStoreError>> + Send + '_>,
247 > {
248 let collection = collection.to_owned();
249 Box::pin(async move {
250 self.delete_collection(&collection)
251 .await
252 .map_err(|e| crate::VectorStoreError::Collection(e.to_string()))
253 })
254 }
255
256 fn upsert(
257 &self,
258 collection: &str,
259 points: Vec<crate::VectorPoint>,
260 ) -> std::pin::Pin<
261 Box<dyn std::future::Future<Output = Result<(), crate::VectorStoreError>> + Send + '_>,
262 > {
263 let collection = collection.to_owned();
264 Box::pin(async move {
265 let qdrant_points: Vec<PointStruct> = points
266 .into_iter()
267 .map(|p| {
268 let payload: HashMap<String, qdrant_client::qdrant::Value> =
269 serde_json::from_value(serde_json::Value::Object(
270 p.payload.into_iter().collect(),
271 ))
272 .unwrap_or_default();
273 PointStruct::new(p.id, p.vector, payload)
274 })
275 .collect();
276 self.upsert(&collection, qdrant_points)
277 .await
278 .map_err(|e| crate::VectorStoreError::Upsert(e.to_string()))
279 })
280 }
281
282 fn search(
283 &self,
284 collection: &str,
285 vector: Vec<f32>,
286 limit: u64,
287 filter: Option<crate::VectorFilter>,
288 ) -> std::pin::Pin<
289 Box<
290 dyn std::future::Future<
291 Output = Result<Vec<crate::ScoredVectorPoint>, crate::VectorStoreError>,
292 > + Send
293 + '_,
294 >,
295 > {
296 let collection = collection.to_owned();
297 Box::pin(async move {
298 let qdrant_filter = filter.map(vector_filter_to_qdrant);
299 let results = self
300 .search(&collection, vector, limit, qdrant_filter)
301 .await
302 .map_err(|e| crate::VectorStoreError::Search(e.to_string()))?;
303 Ok(results.into_iter().map(scored_point_to_vector).collect())
304 })
305 }
306
307 fn delete_by_ids(
308 &self,
309 collection: &str,
310 ids: Vec<String>,
311 ) -> std::pin::Pin<
312 Box<dyn std::future::Future<Output = Result<(), crate::VectorStoreError>> + Send + '_>,
313 > {
314 let collection = collection.to_owned();
315 Box::pin(async move {
316 let point_ids: Vec<PointId> = ids.into_iter().map(PointId::from).collect();
317 self.delete_by_ids(&collection, point_ids)
318 .await
319 .map_err(|e| crate::VectorStoreError::Delete(e.to_string()))
320 })
321 }
322
323 fn scroll_all(
324 &self,
325 collection: &str,
326 key_field: &str,
327 ) -> std::pin::Pin<
328 Box<
329 dyn std::future::Future<
330 Output = Result<
331 HashMap<String, HashMap<String, String>>,
332 crate::VectorStoreError,
333 >,
334 > + Send
335 + '_,
336 >,
337 > {
338 let collection = collection.to_owned();
339 let key_field = key_field.to_owned();
340 Box::pin(async move {
341 self.scroll_all(&collection, &key_field)
342 .await
343 .map_err(|e| crate::VectorStoreError::Scroll(e.to_string()))
344 })
345 }
346}
347
348fn vector_filter_to_qdrant(filter: crate::VectorFilter) -> Filter {
349 let must: Vec<_> = filter
350 .must
351 .into_iter()
352 .map(field_condition_to_qdrant)
353 .collect();
354 let must_not: Vec<_> = filter
355 .must_not
356 .into_iter()
357 .map(field_condition_to_qdrant)
358 .collect();
359
360 let mut f = Filter::default();
361 if !must.is_empty() {
362 f.must = must;
363 }
364 if !must_not.is_empty() {
365 f.must_not = must_not;
366 }
367 f
368}
369
370fn field_condition_to_qdrant(cond: crate::FieldCondition) -> qdrant_client::qdrant::Condition {
371 match cond.value {
372 crate::FieldValue::Integer(v) => qdrant_client::qdrant::Condition::matches(cond.field, v),
373 crate::FieldValue::Text(v) => qdrant_client::qdrant::Condition::matches(cond.field, v),
374 }
375}
376
377fn scored_point_to_vector(point: ScoredPoint) -> crate::ScoredVectorPoint {
378 let payload: HashMap<String, serde_json::Value> = point
379 .payload
380 .into_iter()
381 .filter_map(|(k, v)| {
382 let json_val = match v.kind? {
383 Kind::StringValue(s) => serde_json::Value::String(s),
384 Kind::IntegerValue(i) => serde_json::Value::Number(i.into()),
385 Kind::DoubleValue(d) => {
386 serde_json::Number::from_f64(d).map(serde_json::Value::Number)?
387 }
388 Kind::BoolValue(b) => serde_json::Value::Bool(b),
389 _ => return None,
390 };
391 Some((k, json_val))
392 })
393 .collect();
394
395 let id = match point.id.and_then(|pid| pid.point_id_options) {
396 Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u)) => u,
397 Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(n)) => n.to_string(),
398 None => String::new(),
399 };
400
401 crate::ScoredVectorPoint {
402 id,
403 score: point.score,
404 payload,
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_valid_url() {
        assert!(QdrantOps::new("http://localhost:6334").is_ok());
    }

    #[test]
    fn new_invalid_url() {
        assert!(QdrantOps::new("not a valid url").is_err());
    }

    #[test]
    fn debug_format() {
        let ops = QdrantOps::new("http://localhost:6334").expect("valid URL must build");
        let dbg = format!("{ops:?}");
        assert!(dbg.contains("QdrantOps"));
    }

    #[test]
    fn json_to_payload_valid() {
        let value = serde_json::json!({"key": "value", "num": 42});
        let payload =
            QdrantOps::json_to_payload(value).expect("a JSON object must convert to a payload");
        // Check the conversion itself, not just that it succeeded.
        assert_eq!(payload.len(), 2);
        assert!(payload.contains_key("num"));
        assert!(matches!(
            payload.get("key").and_then(|v| v.kind.as_ref()),
            Some(qdrant_client::qdrant::value::Kind::StringValue(s)) if s == "value"
        ));
    }

    #[test]
    fn json_to_payload_empty() {
        let result = QdrantOps::json_to_payload(serde_json::json!({}));
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn json_to_payload_rejects_non_object() {
        // A payload must be a JSON object; arrays cannot become a key/value map.
        assert!(QdrantOps::json_to_payload(serde_json::json!([1, 2])).is_err());
    }
}