swiftide_integrations/qdrant/
mod.rs1mod indexing_node;
8mod persist;
9mod retrieve;
10use std::collections::{HashMap, HashSet};
11
12use std::sync::Arc;
13
14use anyhow::{bail, Context as _, Result};
15use derive_builder::Builder;
16use qdrant_client::qdrant::{self, SparseVectorParamsBuilder, SparseVectorsConfigBuilder};
17
18use swiftide_core::indexing::{EmbeddedField, Node};
19
/// Collection name used when the builder is not given one explicitly.
const DEFAULT_COLLECTION_NAME: &'static str = "swiftide";
/// Qdrant endpoint used by the default client when `QDRANT_URL` cannot be read.
const DEFAULT_QDRANT_URL: &'static str = "http://localhost:6334";
/// Builder default for the `batch_size` field.
const DEFAULT_BATCH_SIZE: usize = 50;
23
24#[derive(Builder, Clone)]
31#[builder(
32 pattern = "owned",
33 setter(strip_option),
34 build_fn(error = "anyhow::Error")
35)]
36pub struct Qdrant {
37 #[builder(setter(into), default = "self.default_client()?")]
42 #[allow(clippy::missing_fields_in_debug)]
43 client: Arc<qdrant_client::Qdrant>,
44 #[builder(default = "DEFAULT_COLLECTION_NAME.to_string()")]
46 #[builder(setter(into))]
47 collection_name: String,
48 vector_size: u64,
50 #[builder(default = "Distance::Cosine")]
51 vector_distance: Distance,
53 #[builder(default = "Some(DEFAULT_BATCH_SIZE)")]
55 batch_size: Option<usize>,
56 #[builder(private, default = "Self::default_vectors()")]
57 pub(crate) vectors: HashMap<EmbeddedField, VectorConfig>,
58 #[builder(private, default)]
59 pub(crate) sparse_vectors: HashMap<EmbeddedField, SparseVectorConfig>,
60}
61
62impl Qdrant {
63 pub fn builder() -> QdrantBuilder {
65 QdrantBuilder::default()
66 }
67
68 pub fn try_from_url(url: impl AsRef<str>) -> Result<QdrantBuilder> {
85 Ok(QdrantBuilder::default().client(
86 qdrant_client::Qdrant::from_url(url.as_ref())
87 .api_key(std::env::var("QDRANT_API_KEY"))
88 .build()?,
89 ))
90 }
91
92 pub async fn create_index_if_not_exists(&self) -> Result<()> {
105 let collection_name = &self.collection_name;
106
107 tracing::debug!("Checking if collection {collection_name} exists");
108 if self.client.collection_exists(collection_name).await? {
109 tracing::debug!("Collection {collection_name} exists");
110 return Ok(());
111 }
112
113 let vectors_config = self.create_vectors_config()?;
114 tracing::debug!(?vectors_config, "Adding vectors config");
115
116 let mut collection =
117 qdrant::CreateCollectionBuilder::new(collection_name).vectors_config(vectors_config);
118
119 if let Some(sparse_vectors_config) = self.create_sparse_vectors_config() {
120 tracing::debug!(?sparse_vectors_config, "Adding sparse vectors config");
121 collection = collection.sparse_vectors_config(sparse_vectors_config);
122 }
123
124 tracing::info!("Creating collection {collection_name}");
125 self.client.create_collection(collection).await?;
126 Ok(())
127 }
128
129 fn create_vectors_config(&self) -> Result<qdrant_client::qdrant::vectors_config::Config> {
130 if self.vectors.is_empty() {
131 bail!("No configured vectors");
132 } else if self.vectors.len() == 1 && self.sparse_vectors.is_empty() {
133 let config = self
134 .vectors
135 .values()
136 .next()
137 .context("Has one vector config")?;
138 let vector_params = self.create_vector_params(config);
139 return Ok(qdrant::vectors_config::Config::Params(vector_params));
140 }
141 let mut map = HashMap::<String, qdrant::VectorParams>::default();
142 for (embedded_field, config) in &self.vectors {
143 let vector_name = embedded_field.to_string();
144 let vector_params = self.create_vector_params(config);
145
146 map.insert(vector_name, vector_params);
147 }
148
149 Ok(qdrant::vectors_config::Config::ParamsMap(
150 qdrant::VectorParamsMap { map },
151 ))
152 }
153
154 fn create_sparse_vectors_config(&self) -> Option<qdrant::SparseVectorConfig> {
155 if self.sparse_vectors.is_empty() {
156 return None;
157 }
158 let mut sparse_vectors_config = SparseVectorsConfigBuilder::default();
159 for embedded_field in self.sparse_vectors.keys() {
160 let vector_name = format!("{embedded_field}_sparse");
161 let vector_params = SparseVectorParamsBuilder::default();
162 sparse_vectors_config.add_named_vector_params(vector_name, vector_params);
163 }
164
165 Some(sparse_vectors_config.into())
166 }
167
168 fn create_vector_params(&self, config: &VectorConfig) -> qdrant::VectorParams {
169 let size = config.vector_size.unwrap_or(self.vector_size);
170 let distance = config.distance.unwrap_or(self.vector_distance);
171
172 tracing::debug!(
173 "Creating vector params: size={}, distance={:?}",
174 size,
175 distance
176 );
177 qdrant::VectorParamsBuilder::new(size, distance).build()
178 }
179
180 pub fn client(&self) -> &Arc<qdrant_client::Qdrant> {
182 &self.client
183 }
184}
185
186impl QdrantBuilder {
187 #[allow(clippy::unused_self)]
188 fn default_client(&self) -> Result<Arc<qdrant_client::Qdrant>> {
189 let client = qdrant_client::Qdrant::from_url(
190 &std::env::var("QDRANT_URL").unwrap_or(DEFAULT_QDRANT_URL.to_string()),
191 )
192 .api_key(std::env::var("QDRANT_API_KEY"))
193 .build()
194 .context("Could not build default qdrant client")?;
195
196 Ok(Arc::new(client))
197 }
198
199 #[must_use]
206 pub fn with_vector(mut self, vector: impl Into<VectorConfig>) -> QdrantBuilder {
207 if self.vectors.is_none() {
208 self = self.vectors(HashMap::default());
209 }
210 let vector = vector.into();
211 if let Some(vectors) = self.vectors.as_mut() {
212 if let Some(overridden_vector) = vectors.insert(vector.embedded_field.clone(), vector) {
213 tracing::warn!(
214 "Overriding named vector config: {}",
215 overridden_vector.embedded_field
216 );
217 }
218 }
219 self
220 }
221
222 #[must_use]
224 pub fn with_sparse_vector(mut self, vector: impl Into<SparseVectorConfig>) -> QdrantBuilder {
225 if self.sparse_vectors.is_none() {
226 self = self.sparse_vectors(HashMap::default());
227 }
228 let vector = vector.into();
229 if let Some(vectors) = self.sparse_vectors.as_mut() {
230 if let Some(overridden_vector) = vectors.insert(vector.embedded_field.clone(), vector) {
231 tracing::warn!(
232 "Overriding named vector config: {}",
233 overridden_vector.embedded_field
234 );
235 }
236 }
237 self
238 }
239
240 fn default_vectors() -> HashMap<EmbeddedField, VectorConfig> {
241 HashMap::from([(EmbeddedField::default(), VectorConfig::default())])
242 }
243}
244
245#[allow(clippy::missing_fields_in_debug)]
246impl std::fmt::Debug for Qdrant {
247 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248 f.debug_struct("Qdrant")
249 .field("collection_name", &self.collection_name)
250 .field("vector_size", &self.vector_size)
251 .field("batch_size", &self.batch_size)
252 .finish()
253 }
254}
255
256#[derive(Clone, Builder, Default)]
260pub struct VectorConfig {
261 #[builder(default)]
263 pub(super) embedded_field: EmbeddedField,
264 #[builder(setter(into, strip_option), default)]
268 vector_size: Option<u64>,
269 #[builder(setter(into, strip_option), default)]
273 distance: Option<qdrant::Distance>,
274}
275
276impl VectorConfig {
277 pub fn builder() -> VectorConfigBuilder {
278 VectorConfigBuilder::default()
279 }
280}
281
282impl From<EmbeddedField> for VectorConfig {
283 fn from(value: EmbeddedField) -> Self {
284 Self {
285 embedded_field: value,
286 ..Default::default()
287 }
288 }
289}
290
291#[derive(Clone, Builder, Default)]
293pub struct SparseVectorConfig {
294 embedded_field: EmbeddedField,
295}
296
297impl From<EmbeddedField> for SparseVectorConfig {
298 fn from(value: EmbeddedField) -> Self {
299 Self {
300 embedded_field: value,
301 }
302 }
303}
304
305pub type Distance = qdrant::Distance;
306
307struct NodeWithVectors<'a> {
309 node: &'a Node,
310 vector_fields: HashSet<&'a EmbeddedField>,
311}
312
313impl<'a> NodeWithVectors<'a> {
314 pub fn new(node: &'a Node, vector_fields: HashSet<&'a EmbeddedField>) -> Self {
315 Self {
316 node,
317 vector_fields,
318 }
319 }
320}