swiftide_integrations/qdrant/
mod.rs1mod indexing_node;
8mod persist;
9mod retrieve;
10use std::collections::{HashMap, HashSet};
11
12use std::sync::Arc;
13
14use anyhow::{Context as _, Result, bail};
15use derive_builder::Builder;
16pub use qdrant_client;
17use qdrant_client::qdrant::{self, SparseVectorParamsBuilder, SparseVectorsConfigBuilder};
18
19use swiftide_core::indexing::{EmbeddedField, TextNode};
20
/// Qdrant URL used by the default client when `QDRANT_URL` is not set.
const DEFAULT_QDRANT_URL: &str = "http://localhost:6334";
/// Collection name used when the builder is not given one.
const DEFAULT_COLLECTION_NAME: &str = "swiftide";
/// Batch size the builder falls back to when none is configured.
const DEFAULT_BATCH_SIZE: usize = 50;
24
25#[derive(Builder, Clone)]
32#[builder(
33 pattern = "owned",
34 setter(strip_option),
35 build_fn(error = "anyhow::Error")
36)]
37pub struct Qdrant {
38 #[builder(setter(into), default = "self.default_client()?")]
43 #[allow(clippy::missing_fields_in_debug)]
44 client: Arc<qdrant_client::Qdrant>,
45 #[builder(default = "DEFAULT_COLLECTION_NAME.to_string()")]
47 #[builder(setter(into))]
48 collection_name: String,
49 vector_size: u64,
51 #[builder(default = "Distance::Cosine")]
52 vector_distance: Distance,
54 #[builder(default = "Some(DEFAULT_BATCH_SIZE)")]
56 batch_size: Option<usize>,
57 #[builder(private, default = "Self::default_vectors()")]
58 pub(crate) vectors: HashMap<EmbeddedField, VectorConfig>,
59 #[builder(private, default)]
60 pub(crate) sparse_vectors: HashMap<EmbeddedField, SparseVectorConfig>,
61}
62
63impl Qdrant {
64 pub fn builder() -> QdrantBuilder {
66 QdrantBuilder::default()
67 }
68
69 pub fn try_from_url(url: impl AsRef<str>) -> Result<QdrantBuilder> {
86 Ok(QdrantBuilder::default().client(
87 qdrant_client::Qdrant::from_url(url.as_ref())
88 .api_key(std::env::var("QDRANT_API_KEY"))
89 .build()?,
90 ))
91 }
92
93 pub async fn create_index_if_not_exists(&self) -> Result<()> {
106 let collection_name = &self.collection_name;
107
108 tracing::debug!("Checking if collection {collection_name} exists");
109 if self.client.collection_exists(collection_name).await? {
110 tracing::warn!(
111 "Collection {collection_name} exists, skipping collection creation; if vector configurations have not changed, you can ignore this message"
112 );
113 return Ok(());
114 }
115
116 let vectors_config = self.create_vectors_config()?;
117 tracing::debug!(?vectors_config, "Adding vectors config");
118
119 let mut collection =
120 qdrant::CreateCollectionBuilder::new(collection_name).vectors_config(vectors_config);
121
122 if let Some(sparse_vectors_config) = self.create_sparse_vectors_config() {
123 tracing::debug!(?sparse_vectors_config, "Adding sparse vectors config");
124 collection = collection.sparse_vectors_config(sparse_vectors_config);
125 }
126
127 tracing::info!("Creating collection {collection_name}");
128 self.client.create_collection(collection).await?;
129 Ok(())
130 }
131
132 fn create_vectors_config(&self) -> Result<qdrant_client::qdrant::vectors_config::Config> {
133 if self.vectors.is_empty() {
134 bail!("No configured vectors");
135 } else if self.vectors.len() == 1 && self.sparse_vectors.is_empty() {
136 let config = self
137 .vectors
138 .values()
139 .next()
140 .context("Has one vector config")?;
141 let vector_params = self.create_vector_params(config);
142 return Ok(qdrant::vectors_config::Config::Params(vector_params));
143 }
144 let mut map = HashMap::<String, qdrant::VectorParams>::default();
145 for (embedded_field, config) in &self.vectors {
146 let vector_name = embedded_field.to_string();
147 let vector_params = self.create_vector_params(config);
148
149 map.insert(vector_name, vector_params);
150 }
151
152 Ok(qdrant::vectors_config::Config::ParamsMap(
153 qdrant::VectorParamsMap { map },
154 ))
155 }
156
157 fn create_sparse_vectors_config(&self) -> Option<qdrant::SparseVectorConfig> {
158 if self.sparse_vectors.is_empty() {
159 return None;
160 }
161 let mut sparse_vectors_config = SparseVectorsConfigBuilder::default();
162 for embedded_field in self.sparse_vectors.keys() {
163 let vector_name = format!("{embedded_field}_sparse");
164 let vector_params = SparseVectorParamsBuilder::default();
165 sparse_vectors_config.add_named_vector_params(vector_name, vector_params);
166 }
167
168 Some(sparse_vectors_config.into())
169 }
170
171 fn create_vector_params(&self, config: &VectorConfig) -> qdrant::VectorParams {
172 let size = config.vector_size.unwrap_or(self.vector_size);
173 let distance = config.distance.unwrap_or(self.vector_distance);
174
175 tracing::debug!(
176 "Creating vector params: size={}, distance={:?}",
177 size,
178 distance
179 );
180 qdrant::VectorParamsBuilder::new(size, distance).build()
181 }
182
183 pub fn client(&self) -> &Arc<qdrant_client::Qdrant> {
185 &self.client
186 }
187}
188
189impl QdrantBuilder {
190 #[allow(clippy::unused_self)]
191 fn default_client(&self) -> Result<Arc<qdrant_client::Qdrant>> {
192 let client = qdrant_client::Qdrant::from_url(
193 &std::env::var("QDRANT_URL").unwrap_or(DEFAULT_QDRANT_URL.to_string()),
194 )
195 .api_key(std::env::var("QDRANT_API_KEY"))
196 .build()
197 .context("Could not build default qdrant client")?;
198
199 Ok(Arc::new(client))
200 }
201
202 #[must_use]
209 pub fn with_vector(mut self, vector: impl Into<VectorConfig>) -> QdrantBuilder {
210 if self.vectors.is_none() {
211 self = self.vectors(HashMap::default());
212 }
213 let vector = vector.into();
214 if let Some(vectors) = self.vectors.as_mut()
215 && let Some(overridden_vector) = vectors.insert(vector.embedded_field.clone(), vector)
216 {
217 tracing::warn!(
218 "Overriding named vector config: {}",
219 overridden_vector.embedded_field
220 );
221 }
222 self
223 }
224
225 #[must_use]
227 pub fn with_sparse_vector(mut self, vector: impl Into<SparseVectorConfig>) -> QdrantBuilder {
228 if self.sparse_vectors.is_none() {
229 self = self.sparse_vectors(HashMap::default());
230 }
231 let vector = vector.into();
232 if let Some(vectors) = self.sparse_vectors.as_mut()
233 && let Some(overridden_vector) = vectors.insert(vector.embedded_field.clone(), vector)
234 {
235 tracing::warn!(
236 "Overriding named vector config: {}",
237 overridden_vector.embedded_field
238 );
239 }
240 self
241 }
242
243 fn default_vectors() -> HashMap<EmbeddedField, VectorConfig> {
244 HashMap::from([(EmbeddedField::default(), VectorConfig::default())])
245 }
246}
247
248#[allow(clippy::missing_fields_in_debug)]
249impl std::fmt::Debug for Qdrant {
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 f.debug_struct("Qdrant")
252 .field("collection_name", &self.collection_name)
253 .field("vector_size", &self.vector_size)
254 .field("batch_size", &self.batch_size)
255 .finish()
256 }
257}
258
259#[derive(Clone, Builder, Default)]
263pub struct VectorConfig {
264 #[builder(default)]
266 pub(super) embedded_field: EmbeddedField,
267 #[builder(setter(into, strip_option), default)]
271 vector_size: Option<u64>,
272 #[builder(setter(into, strip_option), default)]
276 distance: Option<qdrant::Distance>,
277}
278
279impl VectorConfig {
280 pub fn builder() -> VectorConfigBuilder {
281 VectorConfigBuilder::default()
282 }
283}
284
285impl From<EmbeddedField> for VectorConfig {
286 fn from(value: EmbeddedField) -> Self {
287 Self {
288 embedded_field: value,
289 ..Default::default()
290 }
291 }
292}
293
294#[derive(Clone, Builder, Default)]
296pub struct SparseVectorConfig {
297 embedded_field: EmbeddedField,
298}
299
300impl From<EmbeddedField> for SparseVectorConfig {
301 fn from(value: EmbeddedField) -> Self {
302 Self {
303 embedded_field: value,
304 }
305 }
306}
307
308pub type Distance = qdrant::Distance;
309
310struct NodeWithVectors<'a> {
312 node: &'a TextNode,
313 vector_fields: HashSet<&'a EmbeddedField>,
314}
315
316impl<'a> NodeWithVectors<'a> {
317 pub fn new(node: &'a TextNode, vector_fields: HashSet<&'a EmbeddedField>) -> Self {
318 Self {
319 node,
320 vector_fields,
321 }
322 }
323}